Skip to content

Commit 1e8127c

Browse files
amontoisongdalle
andauthored
Implement our own structure Forest (#190)
* Implement our own structure Forest * Update src/forest.jl * Don't compress the forest before we create a TreeSet * Apply suggestions from code review Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> * Add a docstring for the structure Forest * Fix the docstring of Forest * Update src/forest.jl * Add unit tests for Forest * Fix test/forest.jl --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 848a6f3 commit 1e8127c

7 files changed

Lines changed: 166 additions & 32 deletions

File tree

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ version = "0.4.15"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8-
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
98
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
109
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1110
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -23,7 +22,6 @@ SparseMatrixColoringsColorsExt = "Colors"
2322
ADTypes = "1.2.1"
2423
CliqueTrees = "0.5.2"
2524
Colors = "0.12.11, 0.13"
26-
DataStructures = "0.18"
2725
DocStringExtensions = "0.8,0.9"
2826
LinearAlgebra = "<0.0.1, 1"
2927
Random = "<0.0.1, 1"

docs/src/dev.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ SparseMatrixColorings.symmetric_coefficient
2727
SparseMatrixColorings.star_coloring
2828
SparseMatrixColorings.acyclic_coloring
2929
SparseMatrixColorings.group_by_color
30+
SparseMatrixColorings.Forest
3031
SparseMatrixColorings.StarSet
3132
SparseMatrixColorings.TreeSet
3233
```

src/SparseMatrixColorings.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ module SparseMatrixColorings
1111

1212
using ADTypes: ADTypes
1313
using Base.Iterators: Iterators
14-
using DataStructures: DisjointSets, find_root!, root_union!, num_groups
1514
using DocStringExtensions: README, EXPORTS, SIGNATURES, TYPEDEF, TYPEDFIELDS
1615
using LinearAlgebra:
1716
Adjoint,
@@ -43,6 +42,7 @@ using SparseArrays:
4342
spzeros
4443

4544
include("graph.jl")
45+
include("forest.jl")
4646
include("order.jl")
4747
include("coloring.jl")
4848
include("result.jl")

src/coloring.jl

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -302,11 +302,7 @@ function acyclic_coloring(g::AdjacencyGraph, order::AbstractOrder; postprocessin
302302
forbidden_colors = zeros(Int, nv)
303303
first_neighbor = fill((0, 0), nv) # at first no neighbors have been encountered
304304
first_visit_to_tree = fill((0, 0), ne)
305-
forest = DisjointSets{Tuple{Int,Int}}()
306-
sizehint!(forest.intmap, ne)
307-
sizehint!(forest.revmap, ne)
308-
sizehint!(forest.internal.parents, ne)
309-
sizehint!(forest.internal.ranks, ne)
305+
forest = Forest{Int}(ne)
310306
vertices_in_order = vertices(g, order)
311307

312308
for v in vertices_in_order
@@ -346,10 +342,6 @@ function acyclic_coloring(g::AdjacencyGraph, order::AbstractOrder; postprocessin
346342
end
347343
end
348344

349-
# compress forest
350-
for edge in forest.revmap
351-
find_root!(forest, edge)
352-
end
353345
tree_set = TreeSet(forest, nb_vertices(g))
354346
if postprocessing
355347
# Reuse the vector forbidden_colors to compute offsets during post-processing
@@ -367,11 +359,10 @@ function _prevent_cycle!(
367359
# modified
368360
first_visit_to_tree::AbstractVector{<:Tuple},
369361
forbidden_colors::AbstractVector{<:Integer},
370-
forest::DisjointSets{<:Tuple{Int,Int}},
362+
forest::Forest{<:Integer},
371363
)
372364
wx = _sort(w, x)
373-
root = find_root!(forest, wx) # edge wx belongs to the 2-colored tree T represented by edge "root"
374-
id = forest.intmap[root] # ID of the representative edge "root" of a two-colored tree T.
365+
id = find_root!(forest, wx) # The edge wx belongs to the 2-colored tree T, represented by an edge with an integer ID
375366
(p, q) = first_visit_to_tree[id]
376367
if p != v # T is being visited from vertex v for the first time
377368
vw = _sort(v, w)
@@ -389,7 +380,7 @@ function _grow_star!(
389380
color::AbstractVector{<:Integer},
390381
# modified
391382
first_neighbor::AbstractVector{<:Tuple},
392-
forest::DisjointSets{Tuple{Int,Int}},
383+
forest::Forest{<:Integer},
393384
)
394385
vw = _sort(v, w)
395386
push!(forest, vw) # Create a new tree T_{vw} consisting only of edge vw
@@ -412,7 +403,7 @@ function _merge_trees!(
412403
w::Integer,
413404
x::Integer,
414405
# modified
415-
forest::DisjointSets{Tuple{Int,Int}},
406+
forest::Forest{<:Integer},
416407
)
417408
vw = _sort(v, w)
418409
wx = _sort(w, x)
@@ -438,27 +429,24 @@ struct TreeSet
438429
is_star::Vector{Bool}
439430
end
440431

441-
function TreeSet(forest::DisjointSets{Tuple{Int,Int}}, nvertices::Int)
442-
# forest is a structure DisjointSets from DataStructures.jl
432+
function TreeSet(forest::Forest{Int}, nvertices::Int)
433+
# Forest is a structure defined in forest.jl
443434
# - forest.intmap: a dictionary that maps an edge (i, j) to an integer k
444-
# - forest.revmap: a dictionary that does the reverse of intmap, mapping an integer k to an edge (i, j)
445-
# - forest.internal.ngroups: the number of trees in the forest
446-
ntrees = forest.internal.ngroups
435+
# - forest.num_trees: the number of trees in the forest
436+
nt = forest.num_trees
447437

448438
# dictionary that maps a tree's root to the index of the tree
449439
roots = Dict{Int,Int}()
450-
sizehint!(roots, ntrees)
440+
sizehint!(roots, nt)
451441

452442
# vector of dictionaries where each dictionary stores the neighbors of each vertex in a tree
453-
trees = [Dict{Int,Vector{Int}}() for i in 1:ntrees]
443+
trees = [Dict{Int,Vector{Int}}() for i in 1:nt]
454444

455445
# counter of the number of roots found
456446
k = 0
457-
for edge in forest.revmap
447+
for edge in keys(forest.intmap)
458448
i, j = edge
459-
# forest has already been compressed so this doesn't change its state
460-
root_edge = find_root!(forest, edge)
461-
root = forest.intmap[root_edge]
449+
root = find_root!(forest, edge)
462450

463451
# Update roots
464452
if !haskey(roots, root)
@@ -488,11 +476,11 @@ function TreeSet(forest::DisjointSets{Tuple{Int,Int}}, nvertices::Int)
488476
degrees = Vector{Int}(undef, nvertices)
489477

490478
# reverse breadth first (BFS) traversal order for each tree in the forest
491-
reverse_bfs_orders = [Tuple{Int,Int}[] for i in 1:ntrees]
479+
reverse_bfs_orders = [Tuple{Int,Int}[] for i in 1:nt]
492480

493481
# nvmax is the number of vertices of the biggest tree in the forest
494482
nvmax = 0
495-
for k in 1:ntrees
483+
for k in 1:nt
496484
nb_vertices_tree = length(trees[k])
497485
nvmax = max(nvmax, nb_vertices_tree)
498486
end
@@ -502,9 +490,9 @@ function TreeSet(forest::DisjointSets{Tuple{Int,Int}}, nvertices::Int)
502490

503491
# Specify if each tree in the forest is a star,
504492
# meaning that one vertex is directly connected to all other vertices in the tree
505-
is_star = Vector{Bool}(undef, ntrees)
493+
is_star = Vector{Bool}(undef, nt)
506494

507-
for k in 1:ntrees
495+
for k in 1:nt
508496
tree = trees[k]
509497

510498
# Boolean indicating whether the current tree is a star (a single central vertex connected to all others)

src/forest.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
## Forest
2+
3+
"""
4+
$TYPEDEF
5+
6+
Structure that provides fast union-find operations for constructing a forest during acyclic coloring and bicoloring.
7+
8+
# Fields
9+
10+
$TYPEDFIELDS
11+
"""
12+
mutable struct Forest{T<:Integer}
13+
"current number of edges in the forest"
14+
num_edges::T
15+
"current number of distinct trees in the forest"
16+
num_trees::T
17+
"dictionary mapping each edge represented as a tuple of vertices to its unique integer index"
18+
intmap::Dict{Tuple{T,T},T}
19+
"vector storing the index of a parent in the tree for each edge, used in union-find operations"
20+
parents::Vector{T}
21+
"vector approximating the depth of each tree to optimize path compression"
22+
ranks::Vector{T}
23+
end
24+
25+
function Forest{T}(n::Integer) where {T<:Integer}
26+
num_edges = zero(T)
27+
num_trees = zero(T)
28+
intmap = Dict{Tuple{T,T},T}()
29+
sizehint!(intmap, n)
30+
parents = collect(Base.OneTo(T(n)))
31+
ranks = zeros(T, T(n))
32+
return Forest{T}(num_edges, num_trees, intmap, parents, ranks)
33+
end
34+
35+
function Base.push!(forest::Forest{T}, edge::Tuple{T,T}) where {T<:Integer}
36+
forest.num_edges += 1
37+
forest.intmap[edge] = forest.num_edges
38+
forest.num_trees += one(T)
39+
return forest
40+
end
41+
42+
function _find_root!(parents::Vector{T}, index_edge::T) where {T<:Integer}
43+
p = parents[index_edge]
44+
if parents[p] != p
45+
parents[index_edge] = p = _find_root!(parents, p)
46+
end
47+
return p
48+
end
49+
50+
function find_root!(forest::Forest{T}, edge::Tuple{T,T}) where {T<:Integer}
51+
return _find_root!(forest.parents, forest.intmap[edge])
52+
end
53+
54+
function root_union!(forest::Forest{T}, index_edge1::T, index_edge2::T) where {T<:Integer}
55+
parents = forest.parents
56+
rks = forest.ranks
57+
rank1 = rks[index_edge1]
58+
rank2 = rks[index_edge2]
59+
60+
if rank1 < rank2
61+
index_edge1, index_edge2 = index_edge2, index_edge1
62+
elseif rank1 == rank2
63+
rks[index_edge1] += one(T)
64+
end
65+
parents[index_edge2] = index_edge1
66+
forest.num_trees -= one(T)
67+
return nothing
68+
end

test/forest.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
using SparseMatrixColorings: Forest, find_root!, root_union!
2+
using Test
3+
4+
@testset "Constructor Forest" begin
5+
forest = Forest{Int}(5)
6+
7+
@test forest.num_edges == 0
8+
@test forest.num_trees == 0
9+
@test length(forest.intmap) == 0
10+
@test length(forest.parents) == 5
11+
@test all(forest.parents .== 1:5)
12+
@test all(forest.ranks .== 0)
13+
end
14+
15+
@testset "Push edge" begin
16+
forest = Forest{Int}(5)
17+
18+
push!(forest, (1, 2))
19+
@test forest.num_edges == 1
20+
@test forest.num_trees == 1
21+
@test haskey(forest.intmap, (1, 2))
22+
@test forest.intmap[(1, 2)] == 1
23+
@test forest.num_trees == 1
24+
25+
push!(forest, (3, 4))
26+
@test forest.num_edges == 2
27+
@test forest.num_trees == 2
28+
@test haskey(forest.intmap, (3, 4))
29+
@test forest.intmap[(3, 4)] == 2
30+
@test forest.num_trees == 2
31+
end
32+
33+
@testset "Find root" begin
34+
forest = Forest{Int}(5)
35+
push!(forest, (1, 2))
36+
push!(forest, (3, 4))
37+
38+
@test find_root!(forest, (1, 2)) == 1
39+
@test find_root!(forest, (3, 4)) == 2
40+
end
41+
42+
@testset "Root union" begin
43+
forest = Forest{Int}(5)
44+
push!(forest, (1, 2))
45+
push!(forest, (4, 5))
46+
push!(forest, (2, 4))
47+
@test forest.num_trees == 3
48+
49+
root1 = find_root!(forest, (1, 2))
50+
root3 = find_root!(forest, (2, 4))
51+
@test root1 != root3
52+
53+
root_union!(forest, root1, root3)
54+
@test find_root!(forest, (2, 4)) == 1
55+
@test forest.parents[1] == 1
56+
@test forest.parents[3] == 1
57+
@test forest.ranks[1] == 1
58+
@test forest.ranks[3] == 0
59+
@test forest.num_trees == 2
60+
61+
root1 = find_root!(forest, (1, 2))
62+
root2 = find_root!(forest, (4, 5))
63+
@test root1 != root2
64+
root_union!(forest, root1, root2)
65+
@test find_root!(forest, (4, 5)) == 1
66+
@test forest.parents[1] == 1
67+
@test forest.parents[2] == 1
68+
@test forest.ranks[1] == 1
69+
@test forest.ranks[2] == 0
70+
@test forest.num_trees == 1
71+
72+
push!(forest, (1, 4))
73+
@test forest.num_trees == 2
74+
@test forest.intmap[(1, 4)] == 4
75+
@test forest.parents[4] == 4
76+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ include("utils.jl")
3131
@testset "Graph" begin
3232
include("graph.jl")
3333
end
34+
@testset "Forest" begin
35+
include("forest.jl")
36+
end
3437
@testset "Order" begin
3538
include("order.jl")
3639
end

0 commit comments

Comments
 (0)