Skip to content

Commit 6c098da

Browse files
committed
Add an option decompression_uplo for symmetric results
1 parent 662df63 commit 6c098da

3 files changed

Lines changed: 94 additions & 38 deletions

File tree

src/decompression.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,12 +453,18 @@ function decompress!(
453453
uplo == :F && check_same_pattern(A, S)
454454
fill!(A, zero(eltype(A)))
455455

456+
l = 0
456457
rvS = rowvals(S)
457458
for j in axes(S, 2)
458459
for k in nzrange(S, j)
459460
i = rvS[k]
460461
if in_triangle(i, j, uplo)
461-
A[i, j] = B[compressed_indices[k]]
462+
if result.decompression_uplo == :F
463+
A[i, j] = B[compressed_indices[k]]
464+
else
465+
l += 1
466+
A[i, j] = B[compressed_indices[l]]
467+
end
462468
end
463469
end
464470
end
@@ -472,6 +478,7 @@ function decompress_single_color!(
472478
result::StarSetColoringResult,
473479
uplo::Symbol=:F,
474480
)
481+
@assert result.decompression_uplo == :F
475482
(; ag, compressed_indices, group) = result
476483
(; S) = ag
477484
uplo == :F && check_same_pattern(A, S)
@@ -508,12 +515,13 @@ function decompress!(
508515
(; ag, compressed_indices) = result
509516
(; S) = ag
510517
nzA = nonzeros(A)
511-
if uplo == :F
512-
check_same_pattern(A, S)
518+
if result.decompression_uplo == uplo
519+
uplo == :F && check_same_pattern(A, S)
513520
for k in eachindex(nzA, compressed_indices)
514521
nzA[k] = B[compressed_indices[k]]
515522
end
516523
else
524+
@assert result.decompression_uplo == :F
517525
rvS = rowvals(S)
518526
l = 0 # assume A has the same pattern as the triangle
519527
for j in axes(S, 2)

src/interface.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,9 @@ function _coloring(
345345
t -> maximum(t[3]) + maximum(t[4]), outputs_by_order
346346
) # can't use ncolors without computing the full result
347347
if speed_setting isa WithResult
348-
symmetric_result = StarSetColoringResult(A_and_Aᵀ, ag, color, star_set)
348+
symmetric_result = StarSetColoringResult(
349+
A_and_Aᵀ, ag, color, star_set; decompression_uplo=:L
350+
)
349351
return BicoloringResult(
350352
A,
351353
ag,
@@ -390,7 +392,9 @@ function _coloring(
390392
t -> maximum(t[3]) + maximum(t[4]), outputs_by_order
391393
) # can't use ncolors without computing the full result
392394
if speed_setting isa WithResult
393-
symmetric_result = TreeSetColoringResult(A_and_Aᵀ, ag, color, tree_set, R)
395+
symmetric_result = TreeSetColoringResult(
396+
A_and_Aᵀ, ag, color, tree_set, R; decompression_uplo=:L
397+
)
394398
return BicoloringResult(
395399
A,
396400
ag,

src/result.jl

Lines changed: 77 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -309,55 +309,74 @@ struct StarSetColoringResult{
309309
color::CT
310310
group::GT
311311
compressed_indices::VT
312+
decompression_uplo::Symbol
312313
additional_info::A
313314
end
314315

315316
function StarSetColoringResult(
316317
A::AbstractMatrix,
317318
ag::AdjacencyGraph{T},
318319
color::Vector{<:Integer},
319-
star_set::StarSet{<:Integer},
320+
star_set::StarSet{<:Integer};
321+
decompression_uplo::Symbol=:F,
320322
) where {T<:Integer}
321323
group = group_by_color(T, color)
322-
compressed_indices = star_csc_indices(ag, color, star_set)
323-
return StarSetColoringResult(A, ag, color, group, compressed_indices, nothing)
324+
compressed_indices = star_csc_indices(ag, color, star_set, decompression_uplo)
325+
return StarSetColoringResult(
326+
A, ag, color, group, compressed_indices, decompression_uplo, nothing
327+
)
324328
end
325329

326330
function star_csc_indices(
327-
ag::AdjacencyGraph{T}, color::Vector{<:Integer}, star_set
331+
ag::AdjacencyGraph{T},
332+
color::Vector{<:Integer},
333+
star_set::StarSet{<:Integer},
334+
decompression_uplo::Symbol,
328335
) where {T}
329336
(; star, hub) = star_set
330337
S = pattern(ag)
331338
edge_to_index = edge_indices(ag)
332339
n = S.n
333340
rvS = rowvals(S)
334-
compressed_indices = zeros(T, nnz(S)) # needs to be independent from the storage in the graph, in case the graph gets reused
341+
l = 0
342+
if augmented_graph(ag) && (decompression_uplo != :F)
343+
compressed_indices = zeros(T, nnz(S) ÷ 2)
344+
else
345+
compressed_indices = zeros(T, nnz(S)) # needs to be independent from the storage in the graph, in case the graph gets reused
346+
end
335347
for j in axes(S, 2)
336348
for k in nzrange(S, j)
337349
i = rvS[k]
338350
if i == j
339351
# diagonal coefficients
352+
l += 1
340353
c = color[i]
341-
compressed_indices[k] = (c - 1) * n + i
354+
compressed_indices[l] = (c - 1) * n + i
342355
else
343-
# off-diagonal coefficients
344-
index_ij = edge_to_index[k]
345-
s = star[index_ij]
346-
h = abs(hub[s])
347-
348-
# Assign the non-hub vertex (spoke) to the correct position in spokes
349-
if i == h
350-
# i is the hub and j is the spoke
351-
c = color[i]
352-
compressed_indices[k] = (c - 1) * n + j
353-
else # j == h
354-
# j is the hub and i is the spoke
355-
c = color[j]
356-
compressed_indices[k] = (c - 1) * n + i
356+
if in_triangle(i, j, decompression_uplo)
357+
# off-diagonal coefficients
358+
l += 1
359+
index_ij = edge_to_index[k]
360+
s = star[index_ij]
361+
h = abs(hub[s])
362+
363+
# Assign the non-hub vertex (spoke) to the correct position in spokes
364+
if i == h
365+
# i is the hub and j is the spoke
366+
c = color[i]
367+
compressed_indices[l] = (c - 1) * n + j
368+
else # j == h
369+
# j is the hub and i is the spoke
370+
c = color[j]
371+
compressed_indices[l] = (c - 1) * n + i
372+
end
357373
end
358374
end
359375
end
360376
end
377+
if !augmented_graph(ag) && (decompression_uplo != :F)
378+
resize!(compressed_indices, l)
379+
end
361380
return compressed_indices
362381
end
363382

@@ -391,43 +410,59 @@ struct TreeSetColoringResult{
391410
lower_triangle_offsets::Vector{T}
392411
upper_triangle_offsets::Vector{T}
393412
buffer::Vector{R}
413+
decompression_uplo::Symbol
394414
end
395415

396416
function TreeSetColoringResult(
397417
A::AbstractMatrix,
398418
ag::AdjacencyGraph{T},
399419
color::Vector{<:Integer},
400420
tree_set::TreeSet{<:Integer},
401-
decompression_eltype::Type{R},
421+
decompression_eltype::Type{R};
422+
decompression_uplo::Symbol=:F,
402423
) where {T<:Integer,R}
403424
(; reverse_bfs_orders, tree_edge_indices, nt) = tree_set
404425
(; S) = ag
405426
nvertices = length(color)
406427
group = group_by_color(T, color)
407428
rv = rowvals(S)
408429

409-
# Vector for the decompression of the diagonal coefficients
410-
diagonal_indices = T[]
411-
diagonal_nzind = T[]
412430
ndiag = 0
413-
414431
if !augmented_graph(ag)
415432
for j in axes(S, 2)
416433
for k in nzrange(S, j)
417434
i = rv[k]
418435
if i == j
419-
push!(diagonal_indices, i)
420-
push!(diagonal_nzind, k)
421436
ndiag += 1
422437
end
423438
end
424439
end
425440
end
426441

442+
# Vector for the decompression of the diagonal coefficients
443+
diagonal_indices = Vector{T}(undef, ndiag)
444+
diagonal_nzind = (decompression_uplo == :F) ? Vector{T}(undef, ndiag) : T[]
445+
446+
if !augmented_graph(ag)
447+
l = 0
448+
for j in axes(S, 2)
449+
for k in nzrange(S, j)
450+
i = rv[k]
451+
if i == j
452+
l += 1
453+
diagonal_indices[l] = i
454+
if decompression_uplo == :F
455+
diagonal_nzind[l] = k
456+
end
457+
end
458+
end
459+
end
460+
end
461+
427462
# Vectors for the decompression of the off-diagonal coefficients
428463
nedges = (nnz(S) - ndiag) ÷ 2
429-
lower_triangle_offsets = Vector{T}(undef, nedges)
430-
upper_triangle_offsets = Vector{T}(undef, nedges)
464+
lower_triangle_offsets = decompression_uplo == :U ? T[] : Vector{T}(undef, nedges)
465+
upper_triangle_offsets = decompression_uplo == :L ? T[] : Vector{T}(undef, nedges)
431466

432467
# Index in lower_triangle_offsets and upper_triangle_offsets
433468
index_offsets = 0
@@ -451,21 +486,29 @@ function TreeSetColoringResult(
451486
if in_triangle(i, j, :L)
452487
# uplo = :L or uplo = :F
453488
# S[i,j] is stored at index_ij = (S.colptr[j+1] - offset_L) in S.nzval
454-
lower_triangle_offsets[index_offsets] = length(col_j) - searchsortedfirst(col_j, i) + 1
489+
if decompression_uplo != :U
490+
lower_triangle_offsets[index_offsets] = length(col_j) - searchsortedfirst(col_j, i) + 1
491+
end
455492

456493
# uplo = :U or uplo = :F
457494
# S[j,i] is stored at index_ji = (S.colptr[i] + offset_U) in S.nzval
458-
upper_triangle_offsets[index_offsets] = searchsortedfirst(col_i, j)::Int - 1
495+
if decompression_uplo != :L
496+
upper_triangle_offsets[index_offsets] = searchsortedfirst(col_i, j)::Int - 1
497+
end
459498

460499
# S[i,j] is in the upper triangular part of S
461500
else
462501
# uplo = :U or uplo = :F
463502
# S[i,j] is stored at index_ij = (S.colptr[j] + offset_U) in S.nzval
464-
upper_triangle_offsets[index_offsets] = searchsortedfirst(col_j, i)::Int - 1
503+
if decompression_uplo != :L
504+
upper_triangle_offsets[index_offsets] = searchsortedfirst(col_j, i)::Int - 1
505+
end
465506

466507
# uplo = :L or uplo = :F
467508
# S[j,i] is stored at index_ji = (S.colptr[i+1] - offset_L) in S.nzval
468-
lower_triangle_offsets[index_offsets] = length(col_i) - searchsortedfirst(col_i, j) + 1
509+
if decompression_uplo != :U
510+
lower_triangle_offsets[index_offsets] = length(col_i) - searchsortedfirst(col_i, j) + 1
511+
end
469512
end
470513
#! format: on
471514
end
@@ -488,6 +531,7 @@ function TreeSetColoringResult(
488531
lower_triangle_offsets,
489532
upper_triangle_offsets,
490533
buffer,
534+
decompression_uplo,
491535
)
492536
end
493537

0 commit comments

Comments
 (0)