diff --git a/Project.toml b/Project.toml index 157b96e6..c31b0402 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.4.15" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -23,7 +22,6 @@ SparseMatrixColoringsColorsExt = "Colors" ADTypes = "1.2.1" CliqueTrees = "0.5.2" Colors = "0.12.11, 0.13" -DataStructures = "0.18" DocStringExtensions = "0.8,0.9" LinearAlgebra = "<0.0.1, 1" Random = "<0.0.1, 1" diff --git a/docs/src/dev.md b/docs/src/dev.md index 608d5650..9fff4e65 100644 --- a/docs/src/dev.md +++ b/docs/src/dev.md @@ -27,6 +27,7 @@ SparseMatrixColorings.symmetric_coefficient SparseMatrixColorings.star_coloring SparseMatrixColorings.acyclic_coloring SparseMatrixColorings.group_by_color +SparseMatrixColorings.Forest SparseMatrixColorings.StarSet SparseMatrixColorings.TreeSet ``` diff --git a/src/SparseMatrixColorings.jl b/src/SparseMatrixColorings.jl index 35b9857c..45cb7663 100644 --- a/src/SparseMatrixColorings.jl +++ b/src/SparseMatrixColorings.jl @@ -11,7 +11,6 @@ module SparseMatrixColorings using ADTypes: ADTypes using Base.Iterators: Iterators -using DataStructures: DisjointSets, find_root!, root_union!, num_groups using DocStringExtensions: README, EXPORTS, SIGNATURES, TYPEDEF, TYPEDFIELDS using LinearAlgebra: Adjoint, @@ -43,6 +42,7 @@ using SparseArrays: spzeros include("graph.jl") +include("forest.jl") include("order.jl") include("coloring.jl") include("result.jl") diff --git a/src/coloring.jl b/src/coloring.jl index 3246a299..c1d303f3 100644 --- a/src/coloring.jl +++ b/src/coloring.jl @@ -302,11 +302,7 @@ function acyclic_coloring(g::AdjacencyGraph, order::AbstractOrder; postprocessin forbidden_colors = zeros(Int, nv) first_neighbor = fill((0, 0), nv) # at first no neighbors have been encountered first_visit_to_tree = fill((0, 0), ne) - forest = DisjointSets{Tuple{Int,Int}}() - sizehint!(forest.intmap, ne) - sizehint!(forest.revmap, ne) - sizehint!(forest.internal.parents, ne) - sizehint!(forest.internal.ranks, ne) + forest = Forest{Int}(ne) vertices_in_order = vertices(g, order) for v in vertices_in_order @@ -346,10 +342,6 @@ function acyclic_coloring(g::AdjacencyGraph, order::AbstractOrder; postprocessin end end - # compress forest - for edge in forest.revmap - find_root!(forest, edge) - end tree_set = TreeSet(forest, nb_vertices(g)) if postprocessing # Reuse the vector forbidden_colors to compute offsets during post-processing @@ -367,11 +359,10 @@ function _prevent_cycle!( # modified first_visit_to_tree::AbstractVector{<:Tuple}, forbidden_colors::AbstractVector{<:Integer}, - forest::DisjointSets{<:Tuple{Int,Int}}, + forest::Forest{<:Integer}, ) wx = _sort(w, x) - root = find_root!(forest, wx) # edge wx belongs to the 2-colored tree T represented by edge "root" - id = forest.intmap[root] # ID of the representative edge "root" of a two-colored tree T. + id = find_root!(forest, wx) # The edge wx belongs to the 2-colored tree T, represented by an edge with an integer ID (p, q) = first_visit_to_tree[id] if p != v # T is being visited from vertex v for the first time vw = _sort(v, w) @@ -389,7 +380,7 @@ function _grow_star!( color::AbstractVector{<:Integer}, # modified first_neighbor::AbstractVector{<:Tuple}, - forest::DisjointSets{Tuple{Int,Int}}, + forest::Forest{<:Integer}, ) vw = _sort(v, w) push!(forest, vw) # Create a new tree T_{vw} consisting only of edge vw @@ -412,7 +403,7 @@ function _merge_trees!( w::Integer, x::Integer, # modified - forest::DisjointSets{Tuple{Int,Int}}, + forest::Forest{<:Integer}, ) vw = _sort(v, w) wx = _sort(w, x) @@ -438,27 +429,24 @@ struct TreeSet is_star::Vector{Bool} end -function TreeSet(forest::DisjointSets{Tuple{Int,Int}}, nvertices::Int) - # forest is a structure DisjointSets from DataStructures.jl +function TreeSet(forest::Forest{Int}, nvertices::Int) + # Forest is a structure defined in forest.jl # - forest.intmap: a dictionary that maps an edge (i, j) to an integer k - # - forest.revmap: a dictionary that does the reverse of intmap, mapping an integer k to an edge (i, j) - # - forest.internal.ngroups: the number of trees in the forest - ntrees = forest.internal.ngroups + # - forest.num_trees: the number of trees in the forest + nt = forest.num_trees # dictionary that maps a tree's root to the index of the tree roots = Dict{Int,Int}() - sizehint!(roots, ntrees) + sizehint!(roots, nt) # vector of dictionaries where each dictionary stores the neighbors of each vertex in a tree - trees = [Dict{Int,Vector{Int}}() for i in 1:ntrees] + trees = [Dict{Int,Vector{Int}}() for i in 1:nt] # counter of the number of roots found k = 0 - for edge in forest.revmap + for edge in keys(forest.intmap) i, j = edge - # forest has already been compressed so this doesn't change its state - root_edge = find_root!(forest, edge) - root = forest.intmap[root_edge] + root = find_root!(forest, edge) # Update roots if !haskey(roots, root) @@ -488,11 +476,11 @@ function TreeSet(forest::DisjointSets{Tuple{Int,Int}}, nvertices::Int) degrees = Vector{Int}(undef, nvertices) # reverse breadth first (BFS) traversal order for each tree in the forest - reverse_bfs_orders = [Tuple{Int,Int}[] for i in 1:ntrees] + reverse_bfs_orders = [Tuple{Int,Int}[] for i in 1:nt] # nvmax is the number of vertices of the biggest tree in the forest nvmax = 0 - for k in 1:ntrees + for k in 1:nt nb_vertices_tree = length(trees[k]) nvmax = max(nvmax, nb_vertices_tree) end @@ -502,9 +490,9 @@ function TreeSet(forest::DisjointSets{Tuple{Int,Int}}, nvertices::Int) # Specify if each tree in the forest is a star, # meaning that one vertex is directly connected to all other vertices in the tree - is_star = Vector{Bool}(undef, ntrees) + is_star = Vector{Bool}(undef, nt) - for k in 1:ntrees + for k in 1:nt tree = trees[k] # Boolean indicating whether the current tree is a star (a single central vertex connected to all others) diff --git a/src/forest.jl b/src/forest.jl new file mode 100644 index 00000000..71e8bb00 --- /dev/null +++ b/src/forest.jl @@ -0,0 +1,68 @@ +## Forest + +""" +$TYPEDEF + +Structure that provides fast union-find operations for constructing a forest during acyclic coloring and bicoloring. + +# Fields + +$TYPEDFIELDS +""" +mutable struct Forest{T<:Integer} + "current number of edges in the forest" + num_edges::T + "current number of distinct trees in the forest" + num_trees::T + "dictionary mapping each edge represented as a tuple of vertices to its unique integer index" + intmap::Dict{Tuple{T,T},T} + "vector storing the index of a parent in the tree for each edge, used in union-find operations" + parents::Vector{T} + "vector approximating the depth of each tree to optimize path compression" + ranks::Vector{T} +end + +function Forest{T}(n::Integer) where {T<:Integer} + num_edges = zero(T) + num_trees = zero(T) + intmap = Dict{Tuple{T,T},T}() + sizehint!(intmap, n) + parents = collect(Base.OneTo(T(n))) + ranks = zeros(T, T(n)) + return Forest{T}(num_edges, num_trees, intmap, parents, ranks) +end + +function Base.push!(forest::Forest{T}, edge::Tuple{T,T}) where {T<:Integer} + forest.num_edges += 1 + forest.intmap[edge] = forest.num_edges + forest.num_trees += one(T) + return forest +end + +function _find_root!(parents::Vector{T}, index_edge::T) where {T<:Integer} + p = parents[index_edge] + if parents[p] != p + parents[index_edge] = p = _find_root!(parents, p) + end + return p +end + +function find_root!(forest::Forest{T}, edge::Tuple{T,T}) where {T<:Integer} + return _find_root!(forest.parents, forest.intmap[edge]) +end + +function root_union!(forest::Forest{T}, index_edge1::T, index_edge2::T) where {T<:Integer} + parents = forest.parents + rks = forest.ranks + rank1 = rks[index_edge1] + rank2 = rks[index_edge2] + + if rank1 < rank2 + index_edge1, index_edge2 = index_edge2, index_edge1 + elseif rank1 == rank2 + rks[index_edge1] += one(T) + end + parents[index_edge2] = index_edge1 + forest.num_trees -= one(T) + return nothing +end diff --git a/test/forest.jl b/test/forest.jl new file mode 100644 index 00000000..dc5e6bf0 --- /dev/null +++ b/test/forest.jl @@ -0,0 +1,76 @@ +using SparseMatrixColorings: Forest, find_root!, root_union! +using Test + +@testset "Constructor Forest" begin + forest = Forest{Int}(5) + + @test forest.num_edges == 0 + @test forest.num_trees == 0 + @test length(forest.intmap) == 0 + @test length(forest.parents) == 5 + @test all(forest.parents .== 1:5) + @test all(forest.ranks .== 0) +end + +@testset "Push edge" begin + forest = Forest{Int}(5) + + push!(forest, (1, 2)) + @test forest.num_edges == 1 + @test forest.num_trees == 1 + @test haskey(forest.intmap, (1, 2)) + @test forest.intmap[(1, 2)] == 1 + @test forest.num_trees == 1 + + push!(forest, (3, 4)) + @test forest.num_edges == 2 + @test forest.num_trees == 2 + @test haskey(forest.intmap, (3, 4)) + @test forest.intmap[(3, 4)] == 2 + @test forest.num_trees == 2 +end + +@testset "Find root" begin + forest = Forest{Int}(5) + push!(forest, (1, 2)) + push!(forest, (3, 4)) + + @test find_root!(forest, (1, 2)) == 1 + @test find_root!(forest, (3, 4)) == 2 +end + +@testset "Root union" begin + forest = Forest{Int}(5) + push!(forest, (1, 2)) + push!(forest, (4, 5)) + push!(forest, (2, 4)) + @test forest.num_trees == 3 + + root1 = find_root!(forest, (1, 2)) + root3 = find_root!(forest, (2, 4)) + @test root1 != root3 + + root_union!(forest, root1, root3) + @test find_root!(forest, (2, 4)) == 1 + @test forest.parents[1] == 1 + @test forest.parents[3] == 1 + @test forest.ranks[1] == 1 + @test forest.ranks[3] == 0 + @test forest.num_trees == 2 + + root1 = find_root!(forest, (1, 2)) + root2 = find_root!(forest, (4, 5)) + @test root1 != root2 + root_union!(forest, root1, root2) + @test find_root!(forest, (4, 5)) == 1 + @test forest.parents[1] == 1 + @test forest.parents[2] == 1 + @test forest.ranks[1] == 1 + @test forest.ranks[2] == 0 + @test forest.num_trees == 1 + + push!(forest, (1, 4)) + @test forest.num_trees == 2 + @test forest.intmap[(1, 4)] == 4 + @test forest.parents[4] == 4 +end diff --git a/test/runtests.jl b/test/runtests.jl index a56cca5d..b37f65da 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,9 @@ include("utils.jl") @testset "Graph" begin include("graph.jl") end + @testset "Forest" begin + include("forest.jl") + end @testset "Order" begin include("order.jl") end