Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ version = "0.4.15"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -23,7 +22,6 @@ SparseMatrixColoringsColorsExt = "Colors"
ADTypes = "1.2.1"
CliqueTrees = "0.5.2"
Colors = "0.12.11, 0.13"
DataStructures = "0.18"
DocStringExtensions = "0.8,0.9"
LinearAlgebra = "<0.0.1, 1"
Random = "<0.0.1, 1"
Expand Down
1 change: 1 addition & 0 deletions docs/src/dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ SparseMatrixColorings.symmetric_coefficient
SparseMatrixColorings.star_coloring
SparseMatrixColorings.acyclic_coloring
SparseMatrixColorings.group_by_color
SparseMatrixColorings.Forest
SparseMatrixColorings.StarSet
SparseMatrixColorings.TreeSet
```
Expand Down
2 changes: 1 addition & 1 deletion src/SparseMatrixColorings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ module SparseMatrixColorings

using ADTypes: ADTypes
using Base.Iterators: Iterators
using DataStructures: DisjointSets, find_root!, root_union!, num_groups
using DocStringExtensions: README, EXPORTS, SIGNATURES, TYPEDEF, TYPEDFIELDS
using LinearAlgebra:
Adjoint,
Expand Down Expand Up @@ -43,6 +42,7 @@ using SparseArrays:
spzeros

include("graph.jl")
include("forest.jl")
include("order.jl")
include("coloring.jl")
include("result.jl")
Expand Down
46 changes: 17 additions & 29 deletions src/coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,7 @@ function acyclic_coloring(g::AdjacencyGraph, order::AbstractOrder; postprocessin
forbidden_colors = zeros(Int, nv)
first_neighbor = fill((0, 0), nv) # at first no neighbors have been encountered
first_visit_to_tree = fill((0, 0), ne)
forest = DisjointSets{Tuple{Int,Int}}()
sizehint!(forest.intmap, ne)
sizehint!(forest.revmap, ne)
sizehint!(forest.internal.parents, ne)
sizehint!(forest.internal.ranks, ne)
forest = Forest{Int}(ne)
vertices_in_order = vertices(g, order)

for v in vertices_in_order
Expand Down Expand Up @@ -346,10 +342,6 @@ function acyclic_coloring(g::AdjacencyGraph, order::AbstractOrder; postprocessin
end
end

# compress forest
for edge in forest.revmap
find_root!(forest, edge)
end
tree_set = TreeSet(forest, nb_vertices(g))
if postprocessing
# Reuse the vector forbidden_colors to compute offsets during post-processing
Expand All @@ -367,11 +359,10 @@ function _prevent_cycle!(
# modified
first_visit_to_tree::AbstractVector{<:Tuple},
forbidden_colors::AbstractVector{<:Integer},
forest::DisjointSets{<:Tuple{Int,Int}},
forest::Forest{<:Integer},
)
wx = _sort(w, x)
root = find_root!(forest, wx) # edge wx belongs to the 2-colored tree T represented by edge "root"
id = forest.intmap[root] # ID of the representative edge "root" of a two-colored tree T.
id = find_root!(forest, wx) # The edge wx belongs to the 2-colored tree T, represented by an edge with an integer ID
(p, q) = first_visit_to_tree[id]
if p != v # T is being visited from vertex v for the first time
vw = _sort(v, w)
Expand All @@ -389,7 +380,7 @@ function _grow_star!(
color::AbstractVector{<:Integer},
# modified
first_neighbor::AbstractVector{<:Tuple},
forest::DisjointSets{Tuple{Int,Int}},
forest::Forest{<:Integer},
)
vw = _sort(v, w)
push!(forest, vw) # Create a new tree T_{vw} consisting only of edge vw
Expand All @@ -412,7 +403,7 @@ function _merge_trees!(
w::Integer,
x::Integer,
# modified
forest::DisjointSets{Tuple{Int,Int}},
forest::Forest{<:Integer},
)
vw = _sort(v, w)
wx = _sort(w, x)
Expand All @@ -438,27 +429,24 @@ struct TreeSet
is_star::Vector{Bool}
end

function TreeSet(forest::DisjointSets{Tuple{Int,Int}}, nvertices::Int)
# forest is a structure DisjointSets from DataStructures.jl
function TreeSet(forest::Forest{Int}, nvertices::Int)
# Forest is a structure defined in forest.jl
# - forest.intmap: a dictionary that maps an edge (i, j) to an integer k
# - forest.revmap: a dictionary that does the reverse of intmap, mapping an integer k to an edge (i, j)
# - forest.internal.ngroups: the number of trees in the forest
ntrees = forest.internal.ngroups
# - forest.num_trees: the number of trees in the forest
nt = forest.num_trees

# dictionary that maps a tree's root to the index of the tree
roots = Dict{Int,Int}()
sizehint!(roots, ntrees)
sizehint!(roots, nt)

# vector of dictionaries where each dictionary stores the neighbors of each vertex in a tree
trees = [Dict{Int,Vector{Int}}() for i in 1:ntrees]
trees = [Dict{Int,Vector{Int}}() for i in 1:nt]

# counter of the number of roots found
k = 0
for edge in forest.revmap
for edge in keys(forest.intmap)
i, j = edge
# forest has already been compressed so this doesn't change its state
root_edge = find_root!(forest, edge)
root = forest.intmap[root_edge]
root = find_root!(forest, edge)

# Update roots
if !haskey(roots, root)
Expand Down Expand Up @@ -488,11 +476,11 @@ function TreeSet(forest::DisjointSets{Tuple{Int,Int}}, nvertices::Int)
degrees = Vector{Int}(undef, nvertices)

# reverse breadth first (BFS) traversal order for each tree in the forest
reverse_bfs_orders = [Tuple{Int,Int}[] for i in 1:ntrees]
reverse_bfs_orders = [Tuple{Int,Int}[] for i in 1:nt]

# nvmax is the number of vertices of the biggest tree in the forest
nvmax = 0
for k in 1:ntrees
for k in 1:nt
nb_vertices_tree = length(trees[k])
nvmax = max(nvmax, nb_vertices_tree)
end
Expand All @@ -502,9 +490,9 @@ function TreeSet(forest::DisjointSets{Tuple{Int,Int}}, nvertices::Int)

# Specify if each tree in the forest is a star,
# meaning that one vertex is directly connected to all other vertices in the tree
is_star = Vector{Bool}(undef, ntrees)
is_star = Vector{Bool}(undef, nt)

for k in 1:ntrees
for k in 1:nt
tree = trees[k]

# Boolean indicating whether the current tree is a star (a single central vertex connected to all others)
Expand Down
68 changes: 68 additions & 0 deletions src/forest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
## Forest

"""
$TYPEDEF

Structure that provides fast union-find operations for constructing a forest during acyclic coloring and bicoloring.

# Fields

$TYPEDFIELDS
"""
mutable struct Forest{T<:Integer}
"current number of edges in the forest"
num_edges::T
"current number of distinct trees in the forest"
num_trees::T
"dictionary mapping each edge represented as a tuple of vertices to its unique integer index"
intmap::Dict{Tuple{T,T},T}
"vector storing the index of a parent in the tree for each edge, used in union-find operations"
parents::Vector{T}
"vector approximating the depth of each tree to optimize path compression"
ranks::Vector{T}
end

function Forest{T}(n::Integer) where {T<:Integer}
num_edges = zero(T)
num_trees = zero(T)
intmap = Dict{Tuple{T,T},T}()
sizehint!(intmap, n)
parents = collect(Base.OneTo(T(n)))
ranks = zeros(T, T(n))
return Forest{T}(num_edges, num_trees, intmap, parents, ranks)
end

function Base.push!(forest::Forest{T}, edge::Tuple{T,T}) where {T<:Integer}
forest.num_edges += 1
forest.intmap[edge] = forest.num_edges
forest.num_trees += one(T)
return forest
end

function _find_root!(parents::Vector{T}, index_edge::T) where {T<:Integer}
p = parents[index_edge]
if parents[p] != p
parents[index_edge] = p = _find_root!(parents, p)
end
return p
end

function find_root!(forest::Forest{T}, edge::Tuple{T,T}) where {T<:Integer}
return _find_root!(forest.parents, forest.intmap[edge])
end

function root_union!(forest::Forest{T}, index_edge1::T, index_edge2::T) where {T<:Integer}
parents = forest.parents
rks = forest.ranks
rank1 = rks[index_edge1]
rank2 = rks[index_edge2]

if rank1 < rank2
index_edge1, index_edge2 = index_edge2, index_edge1
elseif rank1 == rank2
rks[index_edge1] += one(T)
end
parents[index_edge2] = index_edge1
forest.num_trees -= one(T)
return nothing
end
76 changes: 76 additions & 0 deletions test/forest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using SparseMatrixColorings: Forest, find_root!, root_union!
using Test

@testset "Constructor Forest" begin
forest = Forest{Int}(5)

@test forest.num_edges == 0
@test forest.num_trees == 0
@test length(forest.intmap) == 0
@test length(forest.parents) == 5
@test all(forest.parents .== 1:5)
@test all(forest.ranks .== 0)
end

@testset "Push edge" begin
forest = Forest{Int}(5)

push!(forest, (1, 2))
@test forest.num_edges == 1
@test forest.num_trees == 1
@test haskey(forest.intmap, (1, 2))
@test forest.intmap[(1, 2)] == 1
@test forest.num_trees == 1

push!(forest, (3, 4))
@test forest.num_edges == 2
@test forest.num_trees == 2
@test haskey(forest.intmap, (3, 4))
@test forest.intmap[(3, 4)] == 2
@test forest.num_trees == 2
end

@testset "Find root" begin
forest = Forest{Int}(5)
push!(forest, (1, 2))
push!(forest, (3, 4))

@test find_root!(forest, (1, 2)) == 1
@test find_root!(forest, (3, 4)) == 2
end

@testset "Root union" begin
forest = Forest{Int}(5)
push!(forest, (1, 2))
push!(forest, (4, 5))
push!(forest, (2, 4))
@test forest.num_trees == 3

root1 = find_root!(forest, (1, 2))
root3 = find_root!(forest, (2, 4))
@test root1 != root3

root_union!(forest, root1, root3)
@test find_root!(forest, (2, 4)) == 1
@test forest.parents[1] == 1
@test forest.parents[3] == 1
@test forest.ranks[1] == 1
@test forest.ranks[3] == 0
@test forest.num_trees == 2

root1 = find_root!(forest, (1, 2))
root2 = find_root!(forest, (4, 5))
@test root1 != root2
root_union!(forest, root1, root2)
@test find_root!(forest, (4, 5)) == 1
@test forest.parents[1] == 1
@test forest.parents[2] == 1
@test forest.ranks[1] == 1
@test forest.ranks[2] == 0
@test forest.num_trees == 1

push!(forest, (1, 4))
@test forest.num_trees == 2
@test forest.intmap[(1, 4)] == 4
@test forest.parents[4] == 4
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ include("utils.jl")
@testset "Graph" begin
include("graph.jl")
end
@testset "Forest" begin
include("forest.jl")
end
@testset "Order" begin
include("order.jl")
end
Expand Down
Loading