Skip to content

Commit 44e75b9

Browse files
committed
Add an option decompression_uplo for symmetric results
1 parent 31999f1 commit 44e75b9

4 files changed

Lines changed: 75 additions & 38 deletions

File tree

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -47,13 +47,15 @@ function SMC.StarSetColoringResult(
4747
A::CuSparseMatrixCSC,
4848
ag::SMC.AdjacencyGraph{T},
4949
color::Vector{<:Integer},
50-
star_set::SMC.StarSet{<:Integer},
50+
star_set::SMC.StarSet{<:Integer};
51+
decompression_uplo::Symbol=:F,
5152
) where {T<:Integer}
53+
@assert decompression_uplo == :F
5254
group = SMC.group_by_color(T, color)
53-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
55+
compressed_indices = SMC.star_csc_indices(ag, color, star_set, decompression_uplo)
5456
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
5557
return SMC.StarSetColoringResult(
56-
A, ag, color, group, compressed_indices, additional_info
58+
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
5759
)
5860
end
5961

@@ -85,13 +87,15 @@ function SMC.StarSetColoringResult(
8587
A::CuSparseMatrixCSR,
8688
ag::SMC.AdjacencyGraph{T},
8789
color::Vector{<:Integer},
88-
star_set::SMC.StarSet{<:Integer},
90+
star_set::SMC.StarSet{<:Integer};
91+
decompression_uplo::Symbol=:F,
8992
) where {T<:Integer}
93+
@assert decompression_uplo == :F
9094
group = SMC.group_by_color(T, color)
91-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
95+
compressed_indices = SMC.star_csc_indices(ag, color, star_set, decompression_uplo)
9296
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices))
9397
return SMC.StarSetColoringResult(
94-
A, ag, color, group, compressed_indices, additional_info
98+
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
9599
)
96100
end
97101

src/decompression.jl

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -448,6 +448,7 @@ end
448448
function decompress!(
449449
A::AbstractMatrix, B::AbstractMatrix, result::StarSetColoringResult, uplo::Symbol=:F
450450
)
451+
@assert result.decompression_uplo == :F
451452
(; ag, compressed_indices) = result
452453
(; S) = ag
453454
check_compatible_pattern(A, ag, uplo)
@@ -472,6 +473,7 @@ function decompress_single_color!(
472473
result::StarSetColoringResult,
473474
uplo::Symbol=:F,
474475
)
476+
@assert result.decompression_uplo == :F
475477
(; ag, compressed_indices, group) = result
476478
(; S) = ag
477479
check_compatible_pattern(A, ag, uplo)
@@ -509,11 +511,12 @@ function decompress!(
509511
(; S) = ag
510512
nzA = nonzeros(A)
511513
check_compatible_pattern(A, ag, uplo)
512-
if uplo == :F
514+
if result.decompression_uplo == uplo
513515
for k in eachindex(nzA, compressed_indices)
514516
nzA[k] = B[compressed_indices[k]]
515517
end
516518
else
519+
@assert result.decompression_uplo == :F
517520
rvS = rowvals(S)
518521
l = 0 # assume A has the same pattern as the triangle
519522
for j in axes(S, 2)
@@ -534,6 +537,7 @@ end
534537
function decompress!(
535538
A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult, uplo::Symbol=:F
536539
)
540+
@assert result.decompression_uplo == :F
537541
(; ag, color, reverse_bfs_orders, tree_edge_indices, nt, diagonal_indices, buffer) =
538542
result
539543
(; S) = ag

src/interface.jl

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -286,7 +286,7 @@ function _coloring(
286286
end
287287
color, star_set = argmin(maximum first, color_and_star_set_by_order)
288288
if speed_setting isa WithResult
289-
return StarSetColoringResult(A, ag, color, star_set)
289+
return StarSetColoringResult(A, ag, color, star_set, :F)
290290
else
291291
return color
292292
end
@@ -307,7 +307,7 @@ function _coloring(
307307
end
308308
color, tree_set = argmin(maximum first, color_and_tree_set_by_order)
309309
if speed_setting isa WithResult
310-
return TreeSetColoringResult(A, ag, color, tree_set, R)
310+
return TreeSetColoringResult(A, ag, color, tree_set, R, :F)
311311
else
312312
return color
313313
end
@@ -345,7 +345,7 @@ function _coloring(
345345
t -> maximum(t[3]) + maximum(t[4]), outputs_by_order
346346
) # can't use ncolors without computing the full result
347347
if speed_setting isa WithResult
348-
symmetric_result = StarSetColoringResult(A_and_Aᵀ, ag, color, star_set)
348+
symmetric_result = StarSetColoringResult(A_and_Aᵀ, ag, color, star_set, :L)
349349
return BicoloringResult(
350350
A,
351351
ag,
@@ -390,7 +390,7 @@ function _coloring(
390390
t -> maximum(t[3]) + maximum(t[4]), outputs_by_order
391391
) # can't use ncolors without computing the full result
392392
if speed_setting isa WithResult
393-
symmetric_result = TreeSetColoringResult(A_and_Aᵀ, ag, color, tree_set, R)
393+
symmetric_result = TreeSetColoringResult(A_and_Aᵀ, ag, color, tree_set, R, :L)
394394
return BicoloringResult(
395395
A,
396396
ag,

src/result.jl

Lines changed: 56 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -309,6 +309,7 @@ struct StarSetColoringResult{
309309
color::CT
310310
group::GT
311311
compressed_indices::VT
312+
decompression_uplo::Symbol
312313
additional_info::A
313314
end
314315

@@ -317,43 +318,58 @@ function StarSetColoringResult(
317318
ag::AdjacencyGraph{T},
318319
color::Vector{<:Integer},
319320
star_set::StarSet{<:Integer},
321+
decompression_uplo::Symbol,
320322
) where {T<:Integer}
321323
group = group_by_color(T, color)
322-
compressed_indices = star_csc_indices(ag, color, star_set)
323-
return StarSetColoringResult(A, ag, color, group, compressed_indices, nothing)
324+
compressed_indices = star_csc_indices(ag, color, star_set, decompression_uplo)
325+
return StarSetColoringResult(
326+
A, ag, color, group, compressed_indices, decompression_uplo, nothing
327+
)
324328
end
325329

326330
function star_csc_indices(
327-
ag::AdjacencyGraph{T}, color::Vector{<:Integer}, star_set
331+
ag::AdjacencyGraph{T},
332+
color::Vector{<:Integer},
333+
star_set::StarSet{<:Integer},
334+
decompression_uplo::Symbol,
328335
) where {T}
329336
(; star, hub) = star_set
330337
S = pattern(ag)
331338
edge_to_index = edge_indices(ag)
332339
n = S.n
333340
rvS = rowvals(S)
334-
compressed_indices = zeros(T, nnz(S)) # needs to be independent from the storage in the graph, in case the graph gets reused
341+
nb_indices = nnz(S)
342+
if decompression_uplo != :F
343+
nb_indices = nb_edges(ag) + ag.nb_self_loops
344+
end
345+
compressed_indices = zeros(T, nb_indices) # needs to be independent from the storage in the graph, in case the graph gets reused
346+
l = 0
335347
for j in axes(S, 2)
336348
for k in nzrange(S, j)
337349
i = rvS[k]
338350
if i == j
339351
# diagonal coefficients
352+
l += 1
340353
c = color[i]
341-
compressed_indices[k] = (c - 1) * n + i
354+
compressed_indices[l] = (c - 1) * n + i
342355
else
343-
# off-diagonal coefficients
344-
index_ij = edge_to_index[k]
345-
s = star[index_ij]
346-
h = abs(hub[s])
347-
348-
# Assign the non-hub vertex (spoke) to the correct position in spokes
349-
if i == h
350-
# i is the hub and j is the spoke
351-
c = color[i]
352-
compressed_indices[k] = (c - 1) * n + j
353-
else # j == h
354-
# j is the hub and i is the spoke
355-
c = color[j]
356-
compressed_indices[k] = (c - 1) * n + i
356+
if in_triangle(i, j, decompression_uplo)
357+
# off-diagonal coefficients
358+
l += 1
359+
index_ij = edge_to_index[k]
360+
s = star[index_ij]
361+
h = abs(hub[s])
362+
363+
# Assign the non-hub vertex (spoke) to the correct position in spokes
364+
if i == h
365+
# i is the hub and j is the spoke
366+
c = color[i]
367+
compressed_indices[l] = (c - 1) * n + j
368+
else # j == h
369+
# j is the hub and i is the spoke
370+
c = color[j]
371+
compressed_indices[l] = (c - 1) * n + i
372+
end
357373
end
358374
end
359375
end
@@ -391,6 +407,7 @@ struct TreeSetColoringResult{
391407
lower_triangle_offsets::Vector{T}
392408
upper_triangle_offsets::Vector{T}
393409
buffer::Vector{R}
410+
decompression_uplo::Symbol
394411
end
395412

396413
function TreeSetColoringResult(
@@ -399,6 +416,7 @@ function TreeSetColoringResult(
399416
color::Vector{<:Integer},
400417
tree_set::TreeSet{<:Integer},
401418
decompression_eltype::Type{R},
419+
decompression_uplo::Symbol,
402420
) where {T<:Integer,R}
403421
(; reverse_bfs_orders, tree_edge_indices, nt) = tree_set
404422
(; S, nb_self_loops) = ag
@@ -408,7 +426,7 @@ function TreeSetColoringResult(
408426

409427
# Vector for the decompression of the diagonal coefficients
410428
diagonal_indices = Vector{T}(undef, nb_self_loops)
411-
diagonal_nzind = Vector{T}(undef, nb_self_loops)
429+
diagonal_nzind = (decompression_uplo == :F) ? Vector{T}(undef, nb_self_loops) : T[]
412430

413431
if !augmented_graph(ag)
414432
l = 0
@@ -418,16 +436,18 @@ function TreeSetColoringResult(
418436
if i == j
419437
l += 1
420438
diagonal_indices[l] = i
421-
diagonal_nzind[l] = k
439+
if decompression_uplo == :F
440+
diagonal_nzind[l] = k
441+
end
422442
end
423443
end
424444
end
425445
end
426446

427447
# Vectors for the decompression of the off-diagonal coefficients
428448
nedges = nb_edges(ag)
429-
lower_triangle_offsets = Vector{T}(undef, nedges)
430-
upper_triangle_offsets = Vector{T}(undef, nedges)
449+
lower_triangle_offsets = decompression_uplo == :U ? T[] : Vector{T}(undef, nedges)
450+
upper_triangle_offsets = decompression_uplo == :L ? T[] : Vector{T}(undef, nedges)
431451

432452
# Index in lower_triangle_offsets and upper_triangle_offsets
433453
index_offsets = 0
@@ -451,21 +471,29 @@ function TreeSetColoringResult(
451471
if in_triangle(i, j, :L)
452472
# uplo = :L or uplo = :F
453473
# S[i,j] is stored at index_ij = (S.colptr[j+1] - offset_L) in S.nzval
454-
lower_triangle_offsets[index_offsets] = length(col_j) - searchsortedfirst(col_j, i) + 1
474+
if decompression_uplo != :U
475+
lower_triangle_offsets[index_offsets] = length(col_j) - searchsortedfirst(col_j, i) + 1
476+
end
455477

456478
# uplo = :U or uplo = :F
457479
# S[j,i] is stored at index_ji = (S.colptr[i] + offset_U) in S.nzval
458-
upper_triangle_offsets[index_offsets] = searchsortedfirst(col_i, j)::Int - 1
480+
if decompression_uplo != :L
481+
upper_triangle_offsets[index_offsets] = searchsortedfirst(col_i, j)::Int - 1
482+
end
459483

460484
# S[i,j] is in the upper triangular part of S
461485
else
462486
# uplo = :U or uplo = :F
463487
# S[i,j] is stored at index_ij = (S.colptr[j] + offset_U) in S.nzval
464-
upper_triangle_offsets[index_offsets] = searchsortedfirst(col_j, i)::Int - 1
488+
if decompression_uplo != :L
489+
upper_triangle_offsets[index_offsets] = searchsortedfirst(col_j, i)::Int - 1
490+
end
465491

466492
# uplo = :L or uplo = :F
467493
# S[j,i] is stored at index_ji = (S.colptr[i+1] - offset_L) in S.nzval
468-
lower_triangle_offsets[index_offsets] = length(col_i) - searchsortedfirst(col_i, j) + 1
494+
if decompression_uplo != :U
495+
lower_triangle_offsets[index_offsets] = length(col_i) - searchsortedfirst(col_i, j) + 1
496+
end
469497
end
470498
#! format: on
471499
end
@@ -488,6 +516,7 @@ function TreeSetColoringResult(
488516
lower_triangle_offsets,
489517
upper_triangle_offsets,
490518
buffer,
519+
decompression_uplo,
491520
)
492521
end
493522

0 commit comments

Comments
 (0)