Skip to content

Commit 5572dab

Browse files
amontoisongdalle
andauthored
Replace num_edges_per_tree by tree_edge_indices in TreeSet (#237)
* Replace num_edges_per_tree by tree_edge_indices in TreeSet * Update two comments --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent b20ac32 commit 5572dab

5 files changed

Lines changed: 63 additions & 56 deletions

File tree

src/coloring.jl

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ $TYPEDFIELDS
376376
struct TreeSet{T}
377377
reverse_bfs_orders::Vector{Tuple{T,T}}
378378
is_star::Vector{Bool}
379-
num_edges_per_tree::Vector{T}
379+
tree_edge_indices::Vector{T}
380+
nt::T
380381
end
381382

382383
function TreeSet(
@@ -388,20 +389,20 @@ function TreeSet(
388389
S = pattern(g)
389390
edge_to_index = edge_indices(g)
390391
nv = nb_vertices(g)
391-
nt = forest.num_trees
392+
(; nt, ranks) = forest
392393

393394
# root_to_tree is a vector that maps a tree's root to the index of the tree
394-
# We can recycle forest.ranks because we don't need it anymore to merge trees
395-
root_to_tree = forest.ranks
395+
# We can recycle the vector "ranks" because we don't need it anymore to merge trees
396+
root_to_tree = ranks
396397
fill!(root_to_tree, zero(T))
397398

398-
# Contains the number of edges per tree
399-
num_edges_per_tree = zeros(T, nt)
399+
# vector specifying the starting and ending indices of edges for each tree
400+
tree_edge_indices = zeros(T, nt + 1)
400401

401402
# vector of dictionaries where each dictionary stores the neighbors of each vertex in a tree
402403
trees = [Dict{T,Vector{T}}() for i in 1:nt]
403404

404-
# current number of roots found
405+
# number of roots found
405406
nr = 0
406407

407408
rvS = rowvals(S)
@@ -418,18 +419,20 @@ function TreeSet(
418419
root_to_tree[root] = nr
419420
end
420421

421-
# index of the tree T that contains this edge
422+
# index of the tree that contains this edge
422423
index_tree = root_to_tree[root]
423-
num_edges_per_tree[index_tree] += 1
424424

425-
# Update the neighbors of i in the tree T
425+
# Update the number of edges for the current tree (shifted by 1 to facilitate the final cumsum)
426+
tree_edge_indices[index_tree + 1] += 1
427+
428+
# Update the neighbors of i in the current tree
426429
if !haskey(trees[index_tree], i)
427430
trees[index_tree][i] = [j]
428431
else
429432
push!(trees[index_tree][i], j)
430433
end
431434

432-
# Update the neighbors of j in the tree T
435+
# Update the neighbors of j in the current tree
433436
if !haskey(trees[index_tree], j)
434437
trees[index_tree][j] = [i]
435438
else
@@ -439,6 +442,12 @@ function TreeSet(
439442
end
440443
end
441444

445+
# Compute a shifted cumulative sum of tree_edge_indices, starting from one
446+
tree_edge_indices[1] = one(T)
447+
for k in 2:(nt + 1)
448+
tree_edge_indices[k] += tree_edge_indices[k - 1]
449+
end
450+
442451
# degrees is a vector of integers that stores the degree of each vertex in a tree
443452
degrees = buffer
444453

@@ -529,7 +538,7 @@ function TreeSet(
529538
is_star[k] = bool_star
530539
end
531540

532-
return TreeSet(reverse_bfs_orders, is_star, num_edges_per_tree)
541+
return TreeSet(reverse_bfs_orders, is_star, tree_edge_indices, nt)
533542
end
534543

535544
## Postprocessing, mirrors decompression code
@@ -597,15 +606,17 @@ function postprocess!(
597606
end
598607
else
599608
# only the colors of non-leaf vertices are used
600-
(; reverse_bfs_orders, is_star, num_edges_per_tree) = star_or_tree_set
609+
(; reverse_bfs_orders, is_star, tree_edge_indices, nt) = star_or_tree_set
601610
nb_trivial_trees = 0
602611

603-
# Index of the first edge in reverse_bfs_orders for the current tree
604-
first = 1
605-
606612
# Iterate through all non-trivial trees
607-
for k in eachindex(num_edges_per_tree)
608-
ne_tree = num_edges_per_tree[k]
613+
for k in 1:nt
614+
# Position of the first edge in the tree
615+
first = tree_edge_indices[k]
616+
617+
# Total number of edges in the tree
618+
ne_tree = tree_edge_indices[k + 1] - first
619+
609620
# Check if we have more than one edge in the tree (non-trivial tree)
610621
if ne_tree > 1
611622
# Determine if the tree is a star
@@ -622,14 +633,17 @@ function postprocess!(
622633
else
623634
nb_trivial_trees += 1
624635
end
625-
first += ne_tree
626636
end
627637

628638
# Process the trivial trees (if any)
629639
if nb_trivial_trees > 0
630-
first = 1
631-
for k in eachindex(num_edges_per_tree)
632-
ne_tree = num_edges_per_tree[k]
640+
for k in 1:nt
641+
# Position of the first edge in the tree
642+
first = tree_edge_indices[k]
643+
644+
# Total number of edges in the tree
645+
ne_tree = tree_edge_indices[k + 1] - first
646+
633647
# Check if we have exactly one edge in the tree
634648
if ne_tree == 1
635649
(i, j) = reverse_bfs_orders[first]
@@ -642,7 +656,6 @@ function postprocess!(
642656
color_used[color[j]] = true
643657
end
644658
end
645-
first += ne_tree
646659
end
647660
end
648661
end

src/decompression.jl

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ end
517517
function decompress!(
518518
A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult, uplo::Symbol=:F
519519
)
520-
(; ag, color, reverse_bfs_orders, num_edges_per_tree, buffer) = result
520+
(; ag, color, reverse_bfs_orders, tree_edge_indices, nt, buffer) = result
521521
(; S) = ag
522522
uplo == :F && check_same_pattern(A, S)
523523
R = eltype(A)
@@ -538,13 +538,11 @@ function decompress!(
538538
end
539539
end
540540

541-
# Index of the first edge in reverse_bfs_orders for the current tree
542-
first = 1
543-
544541
# Recover the off-diagonal coefficients of A
545-
for k in eachindex(num_edges_per_tree)
546-
ne_tree = num_edges_per_tree[k]
547-
last = first + ne_tree - 1
542+
for k in 1:nt
543+
# Positions of the edges for each tree
544+
first = tree_edge_indices[k]
545+
last = tree_edge_indices[k + 1] - 1
548546

549547
# Reset the buffer to zero for all vertices in a tree (except the root)
550548
for pos in first:last
@@ -567,7 +565,6 @@ function decompress!(
567565
A[j, i] = val
568566
end
569567
end
570-
first += ne_tree
571568
end
572569
return A
573570
end
@@ -582,7 +579,8 @@ function decompress!(
582579
ag,
583580
color,
584581
reverse_bfs_orders,
585-
num_edges_per_tree,
582+
tree_edge_indices,
583+
nt,
586584
diagonal_indices,
587585
diagonal_nzind,
588586
lower_triangle_offsets,
@@ -622,16 +620,14 @@ function decompress!(
622620
end
623621
end
624622

625-
# Index of the first edge in reverse_bfs_orders for the current tree
626-
first = 1
627-
628623
# Index of offsets in lower_triangle_offsets and upper_triangle_offsets
629624
counter = 0
630625

631626
# Recover the off-diagonal coefficients of A
632-
for k in eachindex(num_edges_per_tree)
633-
ne_tree = num_edges_per_tree[k]
634-
last = first + ne_tree - 1
627+
for k in 1:nt
628+
# Positions of the edges for each tree
629+
first = tree_edge_indices[k]
630+
last = tree_edge_indices[k + 1] - 1
635631

636632
# Reset the buffer to zero for all vertices in a tree (except the root)
637633
for pos in first:last
@@ -683,7 +679,6 @@ function decompress!(
683679
end
684680
#! format: on
685681
end
686-
first += ne_tree
687682
end
688683
return A
689684
end

src/forest.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,18 @@ $TYPEDFIELDS
1111
"""
1212
mutable struct Forest{T<:Integer}
1313
"current number of distinct trees in the forest"
14-
num_trees::T
14+
nt::T
1515
"vector storing the index of a parent in the tree for each edge, used in union-find operations"
1616
parents::Vector{T}
1717
"vector approximating the depth of each tree to optimize path compression"
1818
ranks::Vector{T}
1919
end
2020

2121
function Forest{T}(n::Integer) where {T<:Integer}
22-
num_trees = T(n)
22+
nt = T(n)
2323
parents = collect(Base.OneTo(T(n)))
2424
ranks = zeros(T, T(n))
25-
return Forest{T}(num_trees, parents, ranks)
25+
return Forest{T}(nt, parents, ranks)
2626
end
2727

2828
function _find_root!(parents::Vector{T}, index_edge::T) where {T<:Integer}
@@ -49,6 +49,6 @@ function root_union!(forest::Forest{T}, index_edge1::T, index_edge2::T) where {T
4949
rks[index_edge1] += one(T)
5050
end
5151
parents[index_edge2] = index_edge1
52-
forest.num_trees -= one(T)
52+
forest.nt -= one(T)
5353
return nothing
5454
end

src/result.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,8 @@ struct TreeSetColoringResult{
314314
color::Vector{T}
315315
group::GT
316316
reverse_bfs_orders::Vector{Tuple{T,T}}
317-
num_edges_per_tree::Vector{T}
317+
tree_edge_indices::Vector{T}
318+
nt::T
318319
diagonal_indices::Vector{T}
319320
diagonal_nzind::Vector{T}
320321
lower_triangle_offsets::Vector{T}
@@ -329,7 +330,7 @@ function TreeSetColoringResult(
329330
tree_set::TreeSet{<:Integer},
330331
decompression_eltype::Type{R},
331332
) where {T<:Integer,R}
332-
(; reverse_bfs_orders, num_edges_per_tree) = tree_set
333+
(; reverse_bfs_orders, tree_edge_indices, nt) = tree_set
333334
(; S) = ag
334335
nvertices = length(color)
335336
group = group_by_color(T, color)
@@ -358,15 +359,13 @@ function TreeSetColoringResult(
358359
lower_triangle_offsets = Vector{T}(undef, nedges)
359360
upper_triangle_offsets = Vector{T}(undef, nedges)
360361

361-
# Index of the first edge in reverse_bfs_orders for the current tree
362-
first = 1
363-
364362
# Index in lower_triangle_offsets and upper_triangle_offsets
365363
index_offsets = 0
366364

367-
for k in eachindex(num_edges_per_tree)
368-
ne_tree = num_edges_per_tree[k]
369-
last = first + ne_tree - 1
365+
for k in 1:nt
366+
# Positions of the edges for each tree
367+
first = tree_edge_indices[k]
368+
last = tree_edge_indices[k + 1] - 1
370369

371370
for pos in first:last
372371
(leaf, neighbor) = reverse_bfs_orders[pos]
@@ -400,7 +399,6 @@ function TreeSetColoringResult(
400399
end
401400
#! format: on
402401
end
403-
first += ne_tree
404402
end
405403

406404
# buffer holds the sum of edge values for subtrees in a tree.
@@ -413,7 +411,8 @@ function TreeSetColoringResult(
413411
color,
414412
group,
415413
reverse_bfs_orders,
416-
num_edges_per_tree,
414+
tree_edge_indices,
415+
nt,
417416
diagonal_indices,
418417
diagonal_nzind,
419418
lower_triangle_offsets,

test/forest.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33

44
@testset "Constructor Forest" begin
55
forest = Forest{Int}(5)
6-
@test forest.num_trees == 5
6+
@test forest.nt == 5
77
@test length(forest.parents) == 5
88
@test all(forest.parents .== 1:5)
99
@test all(forest.ranks .== 0)
@@ -27,7 +27,7 @@ end
2727
@test forest.parents[3] == 1
2828
@test forest.ranks[1] == 1
2929
@test forest.ranks[3] == 0
30-
@test forest.num_trees == 4
30+
@test forest.nt == 4
3131

3232
root1 = find_root!(forest, 1)
3333
root2 = find_root!(forest, 2)
@@ -39,5 +39,5 @@ end
3939
@test forest.parents[2] == 1
4040
@test forest.ranks[1] == 1
4141
@test forest.ranks[2] == 0
42-
@test forest.num_trees == 3
42+
@test forest.nt == 3
4343
end

0 commit comments

Comments
 (0)