Skip to content

Commit 1d64277

Browse files
authored
Implement ConstantColoringAlgorithm (#127)
* Implement ConstantColoringAlgorithm * Improve docs and add ADTypes
1 parent 580cc45 commit 1d64277

5 files changed

Lines changed: 173 additions & 0 deletions

File tree

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ SparseMatrixColorings
1717
coloring
1818
ColoringProblem
1919
GreedyColoringAlgorithm
20+
ConstantColoringAlgorithm
2021
```
2122

2223
## Result analysis

src/SparseMatrixColorings.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,14 @@ include("coloring.jl")
4949
include("result.jl")
5050
include("matrices.jl")
5151
include("interface.jl")
52+
include("constant.jl")
5253
include("decompression.jl")
5354
include("check.jl")
5455
include("examples.jl")
5556

5657
export NaturalOrder, RandomOrder, LargestFirst
5758
export ColoringProblem, GreedyColoringAlgorithm, AbstractColoringResult
59+
export ConstantColoringAlgorithm
5860
export coloring
5961
export column_colors, row_colors
6062
export column_groups, row_groups

src/constant.jl

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""
2+
ConstantColoringAlgorithm{partition} <: ADTypes.AbstractColoringAlgorithm
3+
4+
Coloring algorithm which always returns the same precomputed vector of colors.
5+
Useful when the optimal coloring of a matrix can be determined a priori due to its specific structure (e.g. banded).
6+
7+
It is passed as an argument to the main function [`coloring`](@ref), but will only work if the associated `problem` has `:nonsymmetric` structure.
8+
Indeed, for symmetric coloring problems, we need more than just the vector of colors to allow fast decompression.
9+
10+
# Constructors
11+
12+
ConstantColoringAlgorithm{partition}(matrix_template, color)
13+
ConstantColoringAlgorithm(matrix_template, color; partition=:column)
14+
15+
- `partition::Symbol`: either `:row` or `:column`.
16+
- `matrix_template::AbstractMatrix`: matrix for which the vector of colors was precomputed (the algorithm will only accept matrices of the exact same size).
17+
- `color::Vector{Int}`: vector of integer colors, one for each row or column (depending on `partition`).
18+
19+
!!! warning
20+
The second constructor (based on keyword arguments) is type-unstable.
21+
22+
We do not necessarily verify consistency between the matrix template and the vector of colors, this is the responsibility of the user.
23+
24+
# Example
25+
26+
```jldoctest
27+
julia> using SparseMatrixColorings, LinearAlgebra
28+
29+
julia> matrix_template = Diagonal(ones(Bool, 5))
30+
5×5 Diagonal{Bool, Vector{Bool}}:
31+
1 ⋅ ⋅ ⋅ ⋅
32+
⋅ 1 ⋅ ⋅ ⋅
33+
⋅ ⋅ 1 ⋅ ⋅
34+
⋅ ⋅ ⋅ 1 ⋅
35+
⋅ ⋅ ⋅ ⋅ 1
36+
37+
julia> color = ones(Int, 5) # coloring a Diagonal is trivial
38+
5-element Vector{Int64}:
39+
1
40+
1
41+
1
42+
1
43+
1
44+
45+
julia> problem = ColoringProblem(; structure=:nonsymmetric, partition=:column);
46+
47+
julia> algo = ConstantColoringAlgorithm(matrix_template, color; partition=:column);
48+
49+
julia> result = coloring(similar(matrix_template), problem, algo);
50+
51+
julia> column_colors(result)
52+
5-element Vector{Int64}:
53+
1
54+
1
55+
1
56+
1
57+
1
58+
```
59+
60+
# ADTypes coloring interface
61+
62+
`ConstantColoringAlgorithm` is a subtype of [`ADTypes.AbstractColoringAlgorithm`](@extref ADTypes.AbstractColoringAlgorithm), which means the following methods are also applicable (although they will error if the kind of coloring demanded not consistent):
63+
64+
- [`ADTypes.column_coloring`](@extref ADTypes.column_coloring)
65+
- [`ADTypes.row_coloring`](@extref ADTypes.row_coloring)
66+
"""
67+
struct ConstantColoringAlgorithm{
68+
partition,M<:AbstractMatrix,R<:AbstractColoringResult{:nonsymmetric,partition,:direct}
69+
} <: ADTypes.AbstractColoringAlgorithm
70+
matrix_template::M
71+
color::Vector{Int}
72+
result::R
73+
end
74+
75+
function ConstantColoringAlgorithm{:column}(
76+
matrix_template::AbstractMatrix, color::Vector{Int}
77+
)
78+
S = convert(SparseMatrixCSC, matrix_template)
79+
result = ColumnColoringResult(S, color)
80+
M, R = typeof(matrix_template), typeof(result)
81+
return ConstantColoringAlgorithm{:column,M,R}(matrix_template, color, result)
82+
end
83+
84+
function ConstantColoringAlgorithm{:row}(
85+
matrix_template::AbstractMatrix, color::Vector{Int}
86+
)
87+
S = convert(SparseMatrixCSC, matrix_template)
88+
result = RowColoringResult(S, color)
89+
M, R = typeof(matrix_template), typeof(result)
90+
return ConstantColoringAlgorithm{:row,M,R}(matrix_template, color, result)
91+
end
92+
93+
function ConstantColoringAlgorithm(
94+
matrix_template::AbstractMatrix, color::Vector{Int}; partition=:column
95+
)
96+
return ConstantColoringAlgorithm{partition}(matrix_template, color)
97+
end
98+
99+
function coloring(
100+
A::AbstractMatrix,
101+
::ColoringProblem{:nonsymmetric,partition},
102+
algo::ConstantColoringAlgorithm{partition};
103+
decompression_eltype::Type=Float64,
104+
symmetric_pattern::Bool=false,
105+
) where {partition}
106+
@compat (; matrix_template, result) = algo
107+
if size(A) != size(matrix_template)
108+
throw(
109+
DimensionMismatch(
110+
"`ConstantColoringAlgorithm` expected matrix of size $(size(matrix_template)) but got matrix of size $(size(A))",
111+
),
112+
)
113+
else
114+
return result
115+
end
116+
end
117+
118+
function ADTypes.column_coloring(
119+
A::AbstractMatrix, algo::ConstantColoringAlgorithm{:column}
120+
)
121+
problem = ColoringProblem{:nonsymmetric,:column}()
122+
result = coloring(A, problem, algo)
123+
return column_colors(result)
124+
end
125+
126+
function ADTypes.row_coloring(A::AbstractMatrix, algo::ConstantColoringAlgorithm)
127+
problem = ColoringProblem{:nonsymmetric,:row}()
128+
result = coloring(A, problem, algo)
129+
return row_colors(result)
130+
end

test/constant.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using ADTypes: ADTypes
2+
using SparseMatrixColorings
3+
using Test
4+
5+
matrix_template = ones(100, 200)
6+
7+
@testset "Column coloring" begin
8+
problem = ColoringProblem(; structure=:nonsymmetric, partition=:column)
9+
color = rand(1:5, size(matrix_template, 2))
10+
algo = ConstantColoringAlgorithm(matrix_template, color; partition=:column)
11+
wrong_algo = ConstantColoringAlgorithm(matrix_template, color; partition=:row)
12+
@test_throws DimensionMismatch coloring(transpose(matrix_template), problem, algo)
13+
@test_throws MethodError coloring(matrix_template, problem, wrong_algo)
14+
result = coloring(matrix_template, problem, algo)
15+
@test column_colors(result) == color
16+
@test ADTypes.column_coloring(matrix_template, algo) == color
17+
@test_throws MethodError ADTypes.row_coloring(matrix_template, algo)
18+
end
19+
20+
@testset "Row coloring" begin
21+
problem = ColoringProblem(; structure=:nonsymmetric, partition=:row)
22+
color = rand(1:5, size(matrix_template, 1))
23+
algo = ConstantColoringAlgorithm(matrix_template, color; partition=:row)
24+
@test_throws DimensionMismatch coloring(transpose(matrix_template), problem, algo)
25+
result = coloring(matrix_template, problem, algo)
26+
@test row_colors(result) == color
27+
@test ADTypes.row_coloring(matrix_template, algo) == color
28+
@test_throws MethodError ADTypes.column_coloring(matrix_template, algo)
29+
end
30+
31+
@testset "Symmetric coloring" begin
32+
wrong_problem = ColoringProblem(; structure=:symmetric, partition=:column)
33+
color = rand(1:5, size(matrix_template, 2))
34+
algo = ConstantColoringAlgorithm(matrix_template, color; partition=:column)
35+
@test_throws MethodError coloring(matrix_template, wrong_problem, algo)
36+
@test_throws MethodError ADTypes.symmetric_coloring(matrix_template, algo)
37+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ include("utils.jl")
4242
@testset "Constructors" begin
4343
include("constructors.jl")
4444
end
45+
@testset "Constant coloring" begin
46+
include("constant.jl")
47+
end
4548
end
4649
@testset verbose = true "Correctness" begin
4750
@testset "Small instances" begin

0 commit comments

Comments
 (0)