Skip to content

Commit 91fe159

Browse files
committed
Add StarSetBicoloringResult and TreeSetBicoloringResult
1 parent ca44a88 commit 91fe159

4 files changed

Lines changed: 189 additions & 134 deletions

File tree

docs/src/dev.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ SparseMatrixColorings.RowColoringResult
3939
SparseMatrixColorings.StarSetColoringResult
4040
SparseMatrixColorings.TreeSetColoringResult
4141
SparseMatrixColorings.LinearSystemColoringResult
42-
SparseMatrixColorings.BicoloringResult
42+
SparseMatrixColorings.StarSetBicoloringResult
43+
SparseMatrixColorings.TreeSetBicoloringResult
4344
```
4445

4546
## Decompression

src/decompression.jl

Lines changed: 56 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -528,41 +528,6 @@ function decompress!(
528528
return A
529529
end
530530

531-
function decompress!(
532-
A::SparseMatrixCSC,
533-
Br::AbstractMatrix,
534-
Bc::AbstractMatrix,
535-
symmetric_to_row::Vector{Int},
536-
symmetric_to_column::Vector{Int},
537-
result::StarSetColoringResult,
538-
)
539-
(; ag, color, compressed_indices) = result
540-
(; S) = ag
541-
n = size(Br, 2)
542-
m = size(Bc, 1)
543-
dim = m + n
544-
nzA = nonzeros(A)
545-
rvS = rowvals(S)
546-
l = 0 # assume A has the same pattern as the triangle
547-
for j in axes(S, 2)
548-
for k in nzrange(S, j)
549-
i = rvS[k]
550-
if in_triangle(i, j, :L)
551-
l += 1
552-
j2, i2 = divrem(compressed_indices[k] - 1, dim)
553-
j2 += 1
554-
i2 += 1
555-
if i2 n
556-
nzA[l] = Br[symmetric_to_row[j2], i2]
557-
else
558-
nzA[l] = Bc[i2 - n, symmetric_to_column[j2]]
559-
end
560-
end
561-
end
562-
end
563-
return A
564-
end
565-
566531
## TreeSetColoringResult
567532

568533
function decompress!(
@@ -719,74 +684,6 @@ function decompress!(
719684
return A
720685
end
721686

722-
function decompress!(
723-
A::SparseMatrixCSC{R},
724-
Br::AbstractMatrix{R},
725-
Bc::AbstractMatrix{R},
726-
symmetric_to_row::Vector{Int},
727-
symmetric_to_column::Vector{Int},
728-
result::TreeSetColoringResult,
729-
) where {R<:Real}
730-
(;
731-
ag,
732-
color,
733-
reverse_bfs_orders,
734-
diagonal_indices,
735-
diagonal_nzind,
736-
lower_triangle_offsets,
737-
upper_triangle_offsets,
738-
buffer,
739-
) = result
740-
(; S) = ag
741-
A_colptr = A.colptr
742-
nzA = nonzeros(A)
743-
m = size(Bc, 1)
744-
n = size(Br, 2)
745-
746-
if eltype(buffer) == R
747-
buffer_right_type = buffer
748-
else
749-
buffer_right_type = similar(buffer, R)
750-
end
751-
752-
# Index of offsets in lower_triangle_offsets and upper_triangle_offsets
753-
counter = 0
754-
755-
# Recover the off-diagonal coefficients of A
756-
for k in eachindex(reverse_bfs_orders)
757-
# Reset the buffer to zero for all vertices in a tree (except the root)
758-
for (vertex, _) in reverse_bfs_orders[k]
759-
buffer_right_type[vertex] = zero(R)
760-
end
761-
# Reset the buffer to zero for the root vertex
762-
(_, root) = reverse_bfs_orders[k][end]
763-
buffer_right_type[root] = zero(R)
764-
765-
for (i, j) in reverse_bfs_orders[k]
766-
counter += 1
767-
if i n
768-
val = Br[symmetric_to_row[color[j]], i] - buffer_right_type[i]
769-
else
770-
val = Bc[i - n, symmetric_to_column[color[j]]] - buffer_right_type[i]
771-
end
772-
buffer_right_type[j] = buffer_right_type[j] + val
773-
774-
#! format: off
775-
# A[i,j] is in the lower triangular part of A
776-
if in_triangle(i, j, :L)
777-
nzind = A_colptr[j + 1] - lower_triangle_offsets[counter]
778-
nzA[nzind] = val
779-
# A[i,j] is in the upper triangular part of A
780-
else
781-
nzind = A_colptr[i + 1] - lower_triangle_offsets[counter]
782-
nzA[nzind] = val
783-
end
784-
#! format: on
785-
end
786-
end
787-
return A
788-
end
789-
790687
## MatrixInverseColoringResult
791688

792689
function decompress!(
@@ -861,14 +758,59 @@ function Base.getindex(B::JoinCompressed, i::Int, j::Int)
861758
end
862759
end
863760

864-
function Base.getindex(B::JoinCompressed, k::Int)
865-
dim = B.m + B.n
866-
j, i = divrem(k - 1, dim)
867-
return getindex(B, i + 1, j + 1)
761+
## StarSetBicoloringResult
762+
763+
function decompress!(
764+
A::AbstractMatrix,
765+
Br::AbstractMatrix,
766+
Bc::AbstractMatrix,
767+
result::StarSetBicoloringResult,
768+
)
769+
(; ag, symmetric_color, symmetric_to_row, symmetric_to_column, star_set) = result
770+
(; star, hub, spokes) = star_set
771+
(; S) = ag
772+
fill!(A, zero(eltype(A)))
773+
774+
m, n = size(A)
775+
for s in eachindex(hub, spokes)
776+
j = abs(hub[s])
777+
cj = symmetric_color[j]
778+
for i in spokes[s]
779+
if in_triangle(i, j, :L)
780+
A[i - n, j] = Bc[i - n, symmetric_to_column[cj]]
781+
else
782+
A[j - n, i] = Br[symmetric_to_row[cj], i]
783+
end
784+
end
785+
end
786+
return A
787+
end
788+
789+
function decompress!(
790+
A::SparseMatrixCSC,
791+
Br::AbstractMatrix,
792+
Bc::AbstractMatrix,
793+
result::StarSetBicoloringResult,
794+
)
795+
(; ag, A_indices, compressed_indices, pos_Br) = result
796+
(; S) = ag
797+
nzA = nonzeros(A)
798+
for k in 1:pos_Br
799+
nzA[A_indices[k]] = Br[compressed_indices[k]]
800+
end
801+
for k in (pos_Br + 1):length(nzA)
802+
nzA[A_indices[k]] = Bc[compressed_indices[k]]
803+
end
804+
return A
868805
end
869806

807+
## TreeSetBicoloringResult
808+
870809
function decompress!(
871-
A::AbstractMatrix, Br::AbstractMatrix, Bc::AbstractMatrix, result::BicoloringResult
810+
A::AbstractMatrix,
811+
Br::AbstractMatrix,
812+
Bc::AbstractMatrix,
813+
result::TreeSetBicoloringResult,
872814
)
873815
(; symmetric_to_row, symmetric_to_column, symmetric_result) = result
874816
m, n = size(A)
@@ -879,15 +821,19 @@ function decompress!(
879821
end
880822

881823
function decompress!(
882-
A::SparseMatrixCSC, Br::AbstractMatrix, Bc::AbstractMatrix, result::BicoloringResult
824+
A::SparseMatrixCSC,
825+
Br::AbstractMatrix,
826+
Bc::AbstractMatrix,
827+
result::TreeSetBicoloringResult,
883828
)
884829
(;
885830
symmetric_to_row, symmetric_to_column, symmetric_result, large_colptr, large_rowval
886831
) = result
887832
m, n = size(A)
833+
Br_and_Bc = JoinCompressed(m, n, Br, Bc, symmetric_to_row, symmetric_to_column)
888834
# pretend A is larger
889835
A_and_noAᵀ = SparseMatrixCSC(m + n, m + n, large_colptr, large_rowval, A.nzval)
890836
# decompress lower triangle only
891-
decompress!(A_and_noAᵀ, Br, Bc, symmetric_to_row, symmetric_to_column, symmetric_result)
837+
decompress!(A_and_noAᵀ, Br_and_Bc, symmetric_result, :L)
892838
return A
893839
end

src/interface.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -301,12 +301,15 @@ function _coloring(
301301
)
302302
A_and_Aᵀ = bidirectional_pattern(A; symmetric_pattern)
303303
ag = AdjacencyGraph(A_and_Aᵀ; has_diagonal=false)
304-
color, star_set = star_coloring(ag, algo.order; postprocessing=algo.postprocessing)
304+
symmetric_color, star_set = star_coloring(
305+
ag, algo.order; postprocessing=algo.postprocessing
306+
)
305307
if speed_setting isa WithResult
306-
symmetric_result = StarSetColoringResult(A_and_Aᵀ, ag, color, star_set)
307-
return BicoloringResult(A, ag, symmetric_result)
308+
return StarSetBicoloringResult(A, ag, symmetric_color, star_set)
308309
else
309-
row_color, column_color, _ = remap_colors(color, maximum(color), size(A)...)
310+
row_color, column_color, _ = remap_colors(
311+
symmetric_color, maximum(symmetric_color), size(A)...
312+
)
310313
return row_color, column_color
311314
end
312315
end
@@ -321,12 +324,16 @@ function _coloring(
321324
) where {R}
322325
A_and_Aᵀ = bidirectional_pattern(A; symmetric_pattern)
323326
ag = AdjacencyGraph(A_and_Aᵀ; has_diagonal=false)
324-
color, tree_set = acyclic_coloring(ag, algo.order; postprocessing=algo.postprocessing)
327+
symmetric_color, tree_set = acyclic_coloring(
328+
ag, algo.order; postprocessing=algo.postprocessing
329+
)
325330
if speed_setting isa WithResult
326-
symmetric_result = TreeSetColoringResult(A_and_Aᵀ, ag, color, tree_set, R)
327-
return BicoloringResult(A, ag, symmetric_result)
331+
symmetric_result = TreeSetColoringResult(A_and_Aᵀ, ag, symmetric_color, tree_set, R)
332+
return TreeSetBicoloringResult(A, ag, symmetric_result)
328333
else
329-
row_color, column_color, _ = remap_colors(color, maximum(color), size(A)...)
334+
row_color, column_color, _ = remap_colors(
335+
symmetric_color, maximum(symmetric_color), size(A)...
336+
)
330337
return row_color, column_color
331338
end
332339
end

0 commit comments

Comments
 (0)