Skip to content

Commit 6cdcbf8

Browse files
committed
Start working on optimal coloring
1 parent 48c483a commit 6cdcbf8

7 files changed

Lines changed: 106 additions & 0 deletions

File tree

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1616
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
1717
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
18+
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
19+
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
1820

1921
[extensions]
2022
SparseMatrixColoringsCUDAExt = "CUDA"
2123
SparseMatrixColoringsCliqueTreesExt = "CliqueTrees"
2224
SparseMatrixColoringsColorsExt = "Colors"
25+
SparseMatrixColoringsJuMPExt = ["JuMP", "MathOptInterface"]
2326

2427
[compat]
2528
ADTypes = "1.2.1"
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
module SparseMatrixColoringsJuMPExt
2+
3+
using ADTypes: ADTypes
4+
using JuMP
5+
using LinearAlgebra
6+
import MathOptInterface as MOI
7+
using SparseArrays
8+
using SparseMatrixColorings:
9+
BipartiteGraph, OptimalColoringAlgorithm, nb_vertices, neighbors, pattern, vertices
10+
11+
function optimal_distance2_coloring(
12+
bg::BipartiteGraph, ::Val{side}, optimizer::Type{O}
13+
) where {side,O<:MOI.AbstractOptimizer}
14+
other_side = 3 - side
15+
n = nb_vertices(bg, Val(side))
16+
model = Model(optimizer)
17+
set_silent(model)
18+
@variable(model, 1 <= color[i=1:n] <= i, Int)
19+
@variable(model, ncolors, Int)
20+
@constraint(model, [ncolors; color] in MOI.CountDistinct(n + 1))
21+
for i in vertices(bg, Val(other_side))
22+
neigh = neighbors(bg, Val(other_side), i)
23+
@constraint(model, color[neigh] in MOI.AllDifferent(length(neigh)))
24+
end
25+
@objective(model, Min, ncolors)
26+
optimize!(model)
27+
assert_is_solved_and_feasible(model)
28+
return round.(Int, value.(color))
29+
end
30+
31+
function ADTypes.column_coloring(A::AbstractMatrix, algo::OptimalColoringAlgorithm)
32+
bg = BipartiteGraph(A)
33+
return optimal_distance2_coloring(bg, Val(2), algo.optimizer)
34+
end
35+
36+
function ADTypes.row_coloring(A::AbstractMatrix, algo::OptimalColoringAlgorithm)
37+
bg = BipartiteGraph(A)
38+
return optimal_distance2_coloring(bg, Val(1), algo.optimizer)
39+
end
40+
41+
end

src/SparseMatrixColorings.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ include("decompression.jl")
5656
include("check.jl")
5757
include("examples.jl")
5858
include("show_colors.jl")
59+
include("optimal.jl")
5960

6061
include("precompile.jl")
6162

@@ -64,6 +65,7 @@ export DynamicDegreeBasedOrder, SmallestLast, IncidenceDegree, DynamicLargestFir
6465
export PerfectEliminationOrder
6566
export ColoringProblem, GreedyColoringAlgorithm, AbstractColoringResult
6667
export ConstantColoringAlgorithm
68+
export OptimalColoringAlgorithm
6769
export coloring, fast_coloring
6870
export column_colors, row_colors, ncolors
6971
export column_groups, row_groups

src/optimal.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
OptimalColoringAlgorithm
3+
4+
Coloring algorithm that relies on mathematical programming with [JuMP](https://jump.dev/) to find an optimal coloring.
5+
6+
!!! warning
7+
This algorithm is only available when JuMP is loaded. If you encounter a method error, run `import JuMP` in your REPL and try again.
8+
9+
# Constructor
10+
11+
OptimalColoringAlgorithm(optimizer)
12+
13+
The `optimizer` argument can be any JuMP-compatible optimizer, like `HiGHS.Optimizer`.
14+
You can use [`optimizer_with_attributes`](https://jump.dev/JuMP.jl/stable/api/JuMP/#optimizer_with_attributes) to set solver-specific parameters.
15+
"""
16+
struct OptimalColoringAlgorithm{S} <: ADTypes.AbstractColoringAlgorithm
17+
optimizer::S
18+
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
1212
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
1313
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1414
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
15+
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
1516
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
17+
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
1618
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
1719
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1820
MatrixDepot = "b51810bb-c9f3-55da-ae3c-350fc1fbce05"

test/optimal.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using SparseArrays
2+
using SparseMatrixColorings
3+
using StableRNGs
4+
using Test
5+
6+
rng = StableRNG(0)
7+
8+
asymmetric_params = vcat(
9+
[(10, 20, p) for p in (0.0:0.1:0.5)],
10+
[(20, 10, p) for p in (0.0:0.1:0.5)],
11+
[(100, 200, p) for p in (0.01:0.01:0.05)],
12+
[(200, 100, p) for p in (0.01:0.01:0.05)],
13+
)
14+
15+
@testset "Column coloring" begin
16+
problem = ColoringProblem(; structure=:nonsymmetric, partition=:column)
17+
algo = GreedyColoringAlgorithm()
18+
optalgo = OptimalColoringAlgorithm(HiGHS.Optimizer)
19+
for (m, n, p) in asymmetric_params
20+
A = sprand(rng, m, n, p)
21+
result = coloring(A, problem, algo)
22+
optresult = coloring(A, problem, optalgo)
23+
@test ncolors(result) >= ncolors(optresult)
24+
end
25+
end
26+
27+
@testset "Row coloring" begin
28+
problem = ColoringProblem(; structure=:nonsymmetric, partition=:row)
29+
algo = GreedyColoringAlgorithm()
30+
optalgo = OptimalColoringAlgorithm(HiGHS.Optimizer)
31+
for (m, n, p) in asymmetric_params
32+
A = sprand(rng, m, n, p)
33+
result = coloring(A, problem, algo)
34+
optresult = coloring(A, problem, optalgo)
35+
@test ncolors(result) >= ncolors(optresult)
36+
end
37+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ include("utils.jl")
5858
@testset "Constant coloring" begin
5959
include("constant.jl")
6060
end
61+
@testset "Optimal coloring" begin
62+
include("optimal.jl")
63+
end
6164
@testset "ADTypes coloring algorithms" begin
6265
include("adtypes.jl")
6366
end

0 commit comments

Comments
 (0)