Skip to content

Commit dadb75e

Browse files
committed
Make any coloring algorithm compatible with SMC
1 parent 7efc6bc commit dadb75e

4 files changed

Lines changed: 158 additions & 68 deletions

File tree

src/adtypes.jl

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,32 @@
1-
function coloring(
2-
A::AbstractMatrix,
3-
::ColoringProblem{:nonsymmetric,:column},
4-
algo::ADTypes.NoColoringAlgorithm;
5-
kwargs...,
6-
)
7-
bg = BipartiteGraph(A)
8-
color = convert(Vector{eltype(bg)}, ADTypes.column_coloring(A, algo))
9-
return ColumnColoringResult(A, bg, color)
10-
end
1+
## From ADTypes to SMC
112

123
function coloring(
134
A::AbstractMatrix,
14-
::ColoringProblem{:nonsymmetric,:row},
15-
algo::ADTypes.NoColoringAlgorithm;
16-
kwargs...,
17-
)
18-
bg = BipartiteGraph(A)
19-
color = convert(Vector{eltype(bg)}, ADTypes.row_coloring(A, algo))
20-
return RowColoringResult(A, bg, color)
5+
problem::ColoringProblem{structure,partition},
6+
algo::ADTypes.AbstractColoringAlgorithm;
7+
decompression_eltype::Type{R}=Float64,
8+
symmetric_pattern::Bool=false,
9+
) where {structure,partition,R}
10+
symmetric_pattern = symmetric_pattern || A isa Union{Symmetric,Hermitian}
11+
if structure == :nonsymmetric
12+
if partition == :column
13+
forced_colors = ADTypes.column_coloring(A, algo)
14+
elseif partition == :row
15+
forced_colors = ADTypes.row_coloring(A, algo)
16+
else
17+
A_and_Aᵀ, _ = bidirectional_pattern(A; symmetric_pattern)
18+
forced_colors = ADTypes.symmetric_coloring(A_and_Aᵀ, algo)
19+
end
20+
else
21+
forced_colors = ADTypes.symmetric_coloring(A, algo)
22+
end
23+
return _coloring(
24+
WithResult(),
25+
A,
26+
problem,
27+
GreedyColoringAlgorithm(),
28+
R,
29+
symmetric_pattern;
30+
forced_colors,
31+
)
2132
end

src/coloring.jl

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
"""
2-
partial_distance2_coloring(bg::BipartiteGraph, ::Val{side}, vertices_in_order::AbstractVector)
2+
partial_distance2_coloring(
3+
bg::BipartiteGraph, ::Val{side}, vertices_in_order::AbstractVector;
4+
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing
5+
)
36
47
Compute a distance-2 coloring of the given `side` (`1` or `2`) in the bipartite graph `bg` and return a vector of integer colors.
58
69
A _distance-2 coloring_ is such that two vertices have different colors if they are at distance at most 2.
710
811
The vertices are colored in a greedy fashion, following the order supplied.
912
13+
The optional `forced_colors` keyword argument is used to enforce predefined vertex colors (e.g. coming from another optimization algorithm) but still run the distance-2 coloring procedure to verify correctness.
14+
1015
# See also
1116
1217
- [`BipartiteGraph`](@ref)
@@ -17,11 +22,16 @@ The vertices are colored in a greedy fashion, following the order supplied.
1722
> [_What Color Is Your Jacobian? Graph Coloring for Computing Derivatives_](https://epubs.siam.org/doi/10.1137/S0036144504444711), Gebremedhin et al. (2005), Algorithm 3.2
1823
"""
1924
function partial_distance2_coloring(
20-
bg::BipartiteGraph{T}, ::Val{side}, vertices_in_order::AbstractVector{<:Integer}
25+
bg::BipartiteGraph{T},
26+
::Val{side},
27+
vertices_in_order::AbstractVector{<:Integer};
28+
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
2129
) where {T,side}
2230
color = Vector{T}(undef, nb_vertices(bg, Val(side)))
2331
forbidden_colors = Vector{T}(undef, nb_vertices(bg, Val(side)))
24-
partial_distance2_coloring!(color, forbidden_colors, bg, Val(side), vertices_in_order)
32+
partial_distance2_coloring!(
33+
color, forbidden_colors, bg, Val(side), vertices_in_order; forced_colors
34+
)
2535
return color
2636
end
2737

@@ -30,7 +40,8 @@ function partial_distance2_coloring!(
3040
forbidden_colors::AbstractVector{<:Integer},
3141
bg::BipartiteGraph,
3242
::Val{side},
33-
vertices_in_order::AbstractVector{<:Integer},
43+
vertices_in_order::AbstractVector{<:Integer};
44+
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
3445
) where {side}
3546
color .= 0
3647
forbidden_colors .= 0
@@ -44,17 +55,25 @@ function partial_distance2_coloring!(
4455
end
4556
end
4657
end
47-
for i in eachindex(forbidden_colors)
48-
if forbidden_colors[i] != v
49-
color[v] = i
50-
break
58+
if isnothing(forced_colors)
59+
for i in eachindex(forbidden_colors)
60+
if forbidden_colors[i] != v
61+
color[v] = i
62+
break
63+
end
5164
end
65+
else
66+
@assert forbidden_colors[forced_colors[v]] != v
67+
color[v] = forced_colors[v]
5268
end
5369
end
5470
end
5571

5672
"""
57-
star_coloring(g::AdjacencyGraph, vertices_in_order::AbstractVector, postprocessing::Bool)
73+
star_coloring(
74+
g::AdjacencyGraph, vertices_in_order::AbstractVector, postprocessing::Bool;
75+
forced_colors::Union{AbstractVector,Nothing}=nothing
76+
)
5877
5978
Compute a star coloring of all vertices in the adjacency graph `g` and return a tuple `(color, star_set)`, where
6079
@@ -67,6 +86,8 @@ The vertices are colored in a greedy fashion, following the order supplied.
6786
6887
If `postprocessing=true`, some colors might be replaced with `0` (the "neutral" color) as long as they are not needed during decompression.
6988
89+
The optional `forced_colors` keyword argument is used to enforce predefined vertex colors (e.g. coming from another optimization algorithm) but still run the star coloring procedure to verify correctness and build auxiliary data structures, useful during decompression.
90+
7091
# See also
7192
7293
- [`AdjacencyGraph`](@ref)
@@ -77,7 +98,10 @@ If `postprocessing=true`, some colors might be replaced with `0` (the "neutral"
7798
> [_New Acyclic and Star Coloring Algorithms with Application to Computing Hessians_](https://epubs.siam.org/doi/abs/10.1137/050639879), Gebremedhin et al. (2007), Algorithm 4.1
7899
"""
79100
function star_coloring(
80-
g::AdjacencyGraph{T}, vertices_in_order::AbstractVector{<:Integer}, postprocessing::Bool
101+
g::AdjacencyGraph{T},
102+
vertices_in_order::AbstractVector{<:Integer},
103+
postprocessing::Bool;
104+
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
81105
) where {T<:Integer}
82106
# Initialize data structures
83107
nv = nb_vertices(g)
@@ -115,11 +139,16 @@ function star_coloring(
115139
end
116140
end
117141
end
118-
for i in eachindex(forbidden_colors)
119-
if forbidden_colors[i] != v
120-
color[v] = i
121-
break
142+
if isnothing(forced_colors)
143+
for i in eachindex(forbidden_colors)
144+
if forbidden_colors[i] != v
145+
color[v] = i
146+
break
147+
end
122148
end
149+
else
150+
@assert forbidden_colors[forced_colors[v]] != v
151+
color[v] = forced_colors[v]
123152
end
124153
_update_stars!(star, hub, g, v, color, first_neighbor)
125154
end
@@ -209,7 +238,10 @@ struct StarSet{T}
209238
end
210239

211240
"""
212-
acyclic_coloring(g::AdjacencyGraph, vertices_in_order::AbstractVector, postprocessing::Bool)
241+
acyclic_coloring(
242+
g::AdjacencyGraph, vertices_in_order::AbstractVector, postprocessing::Bool;
243+
forced_colors::Union{AbstractVector,Nothing}=nothing
244+
)
213245
214246
Compute an acyclic coloring of all vertices in the adjacency graph `g` and return a tuple `(color, tree_set)`, where
215247
@@ -222,6 +254,8 @@ The vertices are colored in a greedy fashion, following the order supplied.
222254
223255
If `postprocessing=true`, some colors might be replaced with `0` (the "neutral" color) as long as they are not needed during decompression.
224256
257+
The optional `forced_colors` keyword argument is used to enforce predefined vertex colors (e.g. coming from another optimization algorithm) but still run the acyclic coloring procedure to verify correctness and build auxiliary data structures, useful during decompression.
258+
225259
# See also
226260
227261
- [`AdjacencyGraph`](@ref)
@@ -232,7 +266,10 @@ If `postprocessing=true`, some colors might be replaced with `0` (the "neutral"
232266
> [_New Acyclic and Star Coloring Algorithms with Application to Computing Hessians_](https://epubs.siam.org/doi/abs/10.1137/050639879), Gebremedhin et al. (2007), Algorithm 3.1
233267
"""
234268
function acyclic_coloring(
235-
g::AdjacencyGraph{T}, vertices_in_order::AbstractVector{<:Integer}, postprocessing::Bool
269+
g::AdjacencyGraph{T},
270+
vertices_in_order::AbstractVector{<:Integer},
271+
postprocessing::Bool;
272+
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
236273
) where {T<:Integer}
237274
# Initialize data structures
238275
nv = nb_vertices(g)
@@ -271,11 +308,16 @@ function acyclic_coloring(
271308
end
272309
end
273310
end
274-
for i in eachindex(forbidden_colors)
275-
if forbidden_colors[i] != v
276-
color[v] = i
277-
break
311+
if isnothing(forced_colors)
312+
for i in eachindex(forbidden_colors)
313+
if forbidden_colors[i] != v
314+
color[v] = i
315+
break
316+
end
278317
end
318+
else
319+
@assert forbidden_colors[forced_colors[v]] != v
320+
color[v] = forced_colors[v]
279321
end
280322
for (w, index_vw) in neighbors_with_edge_indices(g, v) # grow two-colored stars around the vertex v
281323
!has_diagonal(g) || (v == w && continue)

src/interface.jl

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -225,12 +225,13 @@ function _coloring(
225225
::ColoringProblem{:nonsymmetric,:column},
226226
algo::GreedyColoringAlgorithm,
227227
decompression_eltype::Type,
228-
symmetric_pattern::Bool,
228+
symmetric_pattern::Bool;
229+
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
229230
)
230231
symmetric_pattern = symmetric_pattern || A isa Union{Symmetric,Hermitian}
231232
bg = BipartiteGraph(A; symmetric_pattern)
232233
vertices_in_order = vertices(bg, Val(2), algo.order)
233-
color = partial_distance2_coloring(bg, Val(2), vertices_in_order)
234+
color = partial_distance2_coloring(bg, Val(2), vertices_in_order; forced_colors)
234235
if speed_setting isa WithResult
235236
return ColumnColoringResult(A, bg, color)
236237
else
@@ -244,12 +245,13 @@ function _coloring(
244245
::ColoringProblem{:nonsymmetric,:row},
245246
algo::GreedyColoringAlgorithm,
246247
decompression_eltype::Type,
247-
symmetric_pattern::Bool,
248+
symmetric_pattern::Bool;
249+
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
248250
)
249251
symmetric_pattern = symmetric_pattern || A isa Union{Symmetric,Hermitian}
250252
bg = BipartiteGraph(A; symmetric_pattern)
251253
vertices_in_order = vertices(bg, Val(1), algo.order)
252-
color = partial_distance2_coloring(bg, Val(1), vertices_in_order)
254+
color = partial_distance2_coloring(bg, Val(1), vertices_in_order; forced_colors)
253255
if speed_setting isa WithResult
254256
return RowColoringResult(A, bg, color)
255257
else
@@ -263,11 +265,14 @@ function _coloring(
263265
::ColoringProblem{:symmetric,:column},
264266
algo::GreedyColoringAlgorithm{:direct},
265267
decompression_eltype::Type,
266-
symmetric_pattern::Bool,
268+
symmetric_pattern::Bool;
269+
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
267270
)
268271
ag = AdjacencyGraph(A; has_diagonal=true)
269272
vertices_in_order = vertices(ag, algo.order)
270-
color, star_set = star_coloring(ag, vertices_in_order, algo.postprocessing)
273+
color, star_set = star_coloring(
274+
ag, vertices_in_order, algo.postprocessing; forced_colors
275+
)
271276
if speed_setting isa WithResult
272277
return StarSetColoringResult(A, ag, color, star_set)
273278
else
@@ -281,11 +286,14 @@ function _coloring(
281286
::ColoringProblem{:symmetric,:column},
282287
algo::GreedyColoringAlgorithm{:substitution},
283288
decompression_eltype::Type{R},
284-
symmetric_pattern::Bool,
289+
symmetric_pattern::Bool;
290+
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
285291
) where {R}
286292
ag = AdjacencyGraph(A; has_diagonal=true)
287293
vertices_in_order = vertices(ag, algo.order)
288-
color, tree_set = acyclic_coloring(ag, vertices_in_order, algo.postprocessing)
294+
color, tree_set = acyclic_coloring(
295+
ag, vertices_in_order, algo.postprocessing; forced_colors
296+
)
289297
if speed_setting isa WithResult
290298
return TreeSetColoringResult(A, ag, color, tree_set, R)
291299
else
@@ -299,12 +307,15 @@ function _coloring(
299307
::ColoringProblem{:nonsymmetric,:bidirectional},
300308
algo::GreedyColoringAlgorithm{:direct},
301309
decompression_eltype::Type{R},
302-
symmetric_pattern::Bool,
310+
symmetric_pattern::Bool;
311+
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
303312
) where {R}
304313
A_and_Aᵀ, edge_to_index = bidirectional_pattern(A; symmetric_pattern)
305314
ag = AdjacencyGraph(A_and_Aᵀ, edge_to_index; has_diagonal=false)
306315
vertices_in_order = vertices(ag, algo.order)
307-
color, star_set = star_coloring(ag, vertices_in_order, algo.postprocessing)
316+
color, star_set = star_coloring(
317+
ag, vertices_in_order, algo.postprocessing; forced_colors
318+
)
308319
if speed_setting isa WithResult
309320
symmetric_result = StarSetColoringResult(A_and_Aᵀ, ag, color, star_set)
310321
return BicoloringResult(A, ag, symmetric_result, R)
@@ -322,12 +333,15 @@ function _coloring(
322333
::ColoringProblem{:nonsymmetric,:bidirectional},
323334
algo::GreedyColoringAlgorithm{:substitution},
324335
decompression_eltype::Type{R},
325-
symmetric_pattern::Bool,
336+
symmetric_pattern::Bool;
337+
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
326338
) where {R}
327339
A_and_Aᵀ, edge_to_index = bidirectional_pattern(A; symmetric_pattern)
328340
ag = AdjacencyGraph(A_and_Aᵀ, edge_to_index; has_diagonal=false)
329341
vertices_in_order = vertices(ag, algo.order)
330-
color, tree_set = acyclic_coloring(ag, vertices_in_order, algo.postprocessing)
342+
color, tree_set = acyclic_coloring(
343+
ag, vertices_in_order, algo.postprocessing; forced_colors
344+
)
331345
if speed_setting isa WithResult
332346
symmetric_result = TreeSetColoringResult(A_and_Aᵀ, ag, color, tree_set, R)
333347
return BicoloringResult(A, ag, symmetric_result, R)

test/adtypes.jl

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,49 @@
11
using ADTypes: ADTypes
22
using SparseArrays
3+
using LinearAlgebra
34
using SparseMatrixColorings
45
using Test
56

6-
@testset "Column coloring" begin
7-
problem = ColoringProblem(; structure=:nonsymmetric, partition=:column)
8-
algo = ADTypes.NoColoringAlgorithm()
9-
A = sprand(10, 20, 0.1)
10-
result = coloring(A, problem, algo)
11-
B = compress(A, result)
12-
@test size(B) == size(A)
13-
@test column_colors(result) == ADTypes.column_coloring(A, algo)
14-
@test decompress(B, result) == A
15-
end
7+
@testset "NoColoringAlgorithm" begin
8+
@testset "Column coloring" begin
9+
problem = ColoringProblem(; structure=:nonsymmetric, partition=:column)
10+
algo = ADTypes.NoColoringAlgorithm()
11+
A = sprand(10, 20, 0.3)
12+
result = coloring(A, problem, algo)
13+
B = compress(A, result)
14+
@test size(B) == size(A)
15+
@test column_colors(result) == ADTypes.column_coloring(A, algo)
16+
@test decompress(B, result) == A
17+
end
18+
19+
@testset "Row coloring" begin
20+
problem = ColoringProblem(; structure=:nonsymmetric, partition=:row)
21+
algo = ADTypes.NoColoringAlgorithm()
22+
A = sprand(10, 20, 0.3)
23+
result = coloring(A, problem, algo)
24+
B = compress(A, result)
25+
@test size(B) == size(A)
26+
@test row_colors(result) == ADTypes.row_coloring(A, algo)
27+
@test decompress(B, result) == A
28+
end
29+
30+
@testset "Symmetric coloring" begin
31+
problem = ColoringProblem(; structure=:symmetric, partition=:column)
32+
algo = ADTypes.NoColoringAlgorithm()
33+
A = Symmetric(sprand(20, 20, 0.3))
34+
result = coloring(A, problem, algo)
35+
B = compress(A, result)
36+
@test size(B) == size(A)
37+
@test column_colors(result) == ADTypes.column_coloring(A, algo)
38+
@test decompress(B, result) == A
39+
end
1640

17-
@testset "Row coloring" begin
18-
problem = ColoringProblem(; structure=:nonsymmetric, partition=:row)
19-
algo = ADTypes.NoColoringAlgorithm()
20-
A = sprand(10, 20, 0.1)
21-
result = coloring(A, problem, algo)
22-
B = compress(A, result)
23-
@test size(B) == size(A)
24-
@test row_colors(result) == ADTypes.row_coloring(A, algo)
25-
@test decompress(B, result) == A
41+
@testset "Bicoloring" begin
42+
problem = ColoringProblem(; structure=:nonsymmetric, partition=:bidirectional)
43+
algo = ADTypes.NoColoringAlgorithm()
44+
A = sprand(10, 20, 0.3)
45+
result = coloring(A, problem, algo)
46+
Br, Bc = compress(A, result)
47+
@test decompress(Br, Bc, result) == A
48+
end
2649
end

0 commit comments

Comments
 (0)