Skip to content

Commit 3968cfe

Browse files
amontoisongdalle
andauthored
Decompression of acyclic coloring with SparseMatrixCSC (#130)
* Decompression of acyclic coloring with SparseMatrixCSC * Decompression of acyclic coloring with SparseMatrixCSC * Decompression of acyclic coloring with SparseMatrixCSC * Fix the error in decompression.jl * Remove compat --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 7146a7d commit 3968cfe

2 files changed

Lines changed: 171 additions & 3 deletions

File tree

src/decompression.jl

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,8 +460,6 @@ end
460460

461461
## TreeSetColoringResult
462462

463-
# TODO: add method for A::SparseMatrixCSC
464-
465463
function decompress!(
466464
A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult, uplo::Symbol=:F
467465
)
@@ -504,6 +502,106 @@ function decompress!(
504502
return A
505503
end
506504

505+
function decompress!(
506+
A::SparseMatrixCSC{R},
507+
B::AbstractMatrix{R},
508+
result::TreeSetColoringResult,
509+
uplo::Symbol=:F,
510+
) where {R<:Real}
511+
(;
512+
color,
513+
vertices_by_tree,
514+
reverse_bfs_orders,
515+
diagonal_indices,
516+
diagonal_nzind,
517+
lower_triangle_offsets,
518+
upper_triangle_offsets,
519+
buffer,
520+
) = result
521+
S = result.ag.S
522+
A_colptr = A.colptr
523+
nzA = nonzeros(A)
524+
uplo == :F && check_same_pattern(A, S)
525+
526+
if eltype(buffer) == R
527+
buffer_right_type = buffer
528+
else
529+
buffer_right_type = similar(buffer, R)
530+
end
531+
532+
# Recover the diagonal coefficients of A
533+
if uplo == :L
534+
for i in diagonal_indices
535+
# A[i, i] is the first element in column i
536+
nzind = A_colptr[i]
537+
nzA[nzind] = B[i, color[i]]
538+
end
539+
elseif uplo == :U
540+
for i in diagonal_indices
541+
# A[i, i] is the last element in column i
542+
nzind = A_colptr[i + 1] - 1
543+
nzA[nzind] = B[i, color[i]]
544+
end
545+
else # uplo == :F
546+
for (k, i) in enumerate(diagonal_indices)
547+
nzind = diagonal_nzind[k]
548+
nzA[nzind] = B[i, color[i]]
549+
end
550+
end
551+
552+
# Index of offsets in lower_triangle_offsets and upper_triangle_offsets
553+
counter = 0
554+
555+
# Recover the off-diagonal coefficients of A
556+
for k in eachindex(vertices_by_tree, reverse_bfs_orders)
557+
for vertex in vertices_by_tree[k]
558+
buffer_right_type[vertex] = zero(R)
559+
end
560+
561+
for (i, j) in reverse_bfs_orders[k]
562+
counter += 1
563+
val = B[i, color[j]] - buffer_right_type[i]
564+
buffer_right_type[j] = buffer_right_type[j] + val
565+
566+
#! format: off
567+
# A[i,j] is in the lower triangular part of A
568+
if in_triangle(i, j, :L)
569+
# uplo = :L or uplo = :F
570+
# A[i,j] is stored at index_ij = (A.colptr[j+1] - offset_L) in A.nzval
571+
if uplo != :U
572+
nzind = A_colptr[j + 1] - lower_triangle_offsets[counter]
573+
nzA[nzind] = val
574+
end
575+
576+
# uplo = :U or uplo = :F
577+
# A[j,i] is stored at index_ji = (A.colptr[i] + offset_U) in A.nzval
578+
if uplo != :L
579+
nzind = A_colptr[i] + upper_triangle_offsets[counter]
580+
nzA[nzind] = val
581+
end
582+
583+
# A[i,j] is in the upper triangular part of A
584+
else
585+
# uplo = :U or uplo = :F
586+
# A[i,j] is stored at index_ij = (A.colptr[j] + offset_U) in A.nzval
587+
if uplo != :L
588+
nzind = A_colptr[j] + upper_triangle_offsets[counter]
589+
nzA[nzind] = val
590+
end
591+
592+
# uplo = :L or uplo = :F
593+
# A[j,i] is stored at index_ji = (A.colptr[i+1] - offset_L) in A.nzval
594+
if uplo != :U
595+
nzind = A_colptr[i + 1] - lower_triangle_offsets[counter]
596+
nzA[nzind] = val
597+
end
598+
end
599+
#! format: on
600+
end
601+
end
602+
return A
603+
end
604+
507605
## MatrixInverseColoringResult
508606

509607
function decompress!(

src/result.jl

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,10 @@ struct TreeSetColoringResult{M<:AbstractMatrix,G<:AdjacencyGraph,V,R} <:
258258
group::V
259259
vertices_by_tree::Vector{Vector{Int}}
260260
reverse_bfs_orders::Vector{Vector{Tuple{Int,Int}}}
261+
diagonal_indices::Vector{Int}
262+
diagonal_nzind::Vector{Int}
263+
lower_triangle_offsets::Vector{Int}
264+
upper_triangle_offsets::Vector{Int}
261265
buffer::Vector{R}
262266
end
263267

@@ -272,6 +276,29 @@ function TreeSetColoringResult(
272276
nvertices = length(color)
273277
group = group_by_color(color)
274278

279+
# Vector for the decompression of the diagonal coefficients
280+
diagonal_indices = Int[]
281+
diagonal_nzind = Int[]
282+
ndiag = 0
283+
284+
n = size(S, 1)
285+
rv = rowvals(S)
286+
for j in axes(S, 2)
287+
for k in nzrange(S, j)
288+
i = rv[k]
289+
if i == j
290+
push!(diagonal_indices, i)
291+
push!(diagonal_nzind, k)
292+
ndiag += 1
293+
end
294+
end
295+
end
296+
297+
# Vectors for the decompression of the off-diagonal coefficients
298+
nedges = (nnz(S) - ndiag) ÷ 2
299+
lower_triangle_offsets = Vector{Int}(undef, nedges)
300+
upper_triangle_offsets = Vector{Int}(undef, nedges)
301+
275302
# forest is a structure DisjointSets from DataStructures.jl
276303
# - forest.intmap: a dictionary that maps an edge (i, j) to an integer k
277304
# - forest.revmap: a dictionary that does the reverse of intmap, mapping an integer k to an edge (i, j)
@@ -333,6 +360,9 @@ function TreeSetColoringResult(
333360
# Create a queue with a fixed size nvmax
334361
queue = Vector{Int}(undef, nvmax)
335362

363+
# Index in lower_triangle_offsets and upper_triangle_offsets
364+
index_offsets = 0
365+
336366
for k in 1:ntrees
337367
tree = trees[k]
338368

@@ -373,6 +403,36 @@ function TreeSetColoringResult(
373403
queue_end += 1
374404
queue[queue_end] = neighbor
375405
end
406+
407+
# Update lower_triangle_offsets and upper_triangle_offsets
408+
i = leaf
409+
j = neighbor
410+
col_i = view(rv, nzrange(S, i))
411+
col_j = view(rv, nzrange(S, j))
412+
index_offsets += 1
413+
414+
#! format: off
415+
# S[i,j] is in the lower triangular part of S
416+
if in_triangle(i, j, :L)
417+
# uplo = :L or uplo = :F
418+
# S[i,j] is stored at index_ij = (S.colptr[j+1] - offset_L) in S.nzval
419+
lower_triangle_offsets[index_offsets] = length(col_j) - searchsortedfirst(col_j, i) + 1
420+
421+
# uplo = :U or uplo = :F
422+
# S[j,i] is stored at index_ji = (S.colptr[i] + offset_U) in S.nzval
423+
upper_triangle_offsets[index_offsets] = searchsortedfirst(col_i, j)::Int - 1
424+
425+
# S[i,j] is in the upper triangular part of S
426+
else
427+
# uplo = :U or uplo = :F
428+
# S[i,j] is stored at index_ij = (S.colptr[j] + offset_U) in S.nzval
429+
upper_triangle_offsets[index_offsets] = searchsortedfirst(col_j, i)::Int - 1
430+
431+
# uplo = :L or uplo = :F
432+
# S[j,i] is stored at index_ji = (S.colptr[i+1] - offset_L) in S.nzval
433+
lower_triangle_offsets[index_offsets] = length(col_i) - searchsortedfirst(col_i, j) + 1
434+
end
435+
#! format: on
376436
end
377437
end
378438
end
@@ -383,7 +443,17 @@ function TreeSetColoringResult(
383443
buffer = Vector{R}(undef, nvertices)
384444

385445
return TreeSetColoringResult(
386-
A, ag, color, group, vertices_by_tree, reverse_bfs_orders, buffer
446+
A,
447+
ag,
448+
color,
449+
group,
450+
vertices_by_tree,
451+
reverse_bfs_orders,
452+
diagonal_indices,
453+
diagonal_nzind,
454+
lower_triangle_offsets,
455+
upper_triangle_offsets,
456+
buffer,
387457
)
388458
end
389459

0 commit comments

Comments
 (0)