Skip to content

Commit 0e5c18b

Browse files
committed
Improve feasibility check
1 parent b344414 commit 0e5c18b

4 files changed

Lines changed: 91 additions & 50 deletions

File tree

src/adtypes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ function coloring(
1414
elseif partition == :row
1515
forced_colors = ADTypes.row_coloring(A, algo)
1616
else
17+
# TODO: improve once https://github.com/SciML/ADTypes.jl/issues/69 is done
1718
A_and_Aᵀ, _ = bidirectional_pattern(A; symmetric_pattern)
1819
forced_colors = ADTypes.symmetric_coloring(A_and_Aᵀ, algo)
1920
end

src/coloring.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
struct InvalidColoringError <: Exception end
2+
13
"""
24
partial_distance2_coloring(
35
bg::BipartiteGraph, ::Val{side}, vertices_in_order::AbstractVector;
@@ -63,8 +65,11 @@ function partial_distance2_coloring!(
6365
end
6466
end
6567
else
66-
@assert forbidden_colors[forced_colors[v]] != v
67-
color[v] = forced_colors[v]
68+
if forbidden_colors[forced_colors[v]] == v
69+
throw(InvalidColoringError())
70+
else
71+
color[v] = forced_colors[v]
72+
end
6873
end
6974
end
7075
end
@@ -147,8 +152,11 @@ function star_coloring(
147152
end
148153
end
149154
else
150-
@assert forbidden_colors[forced_colors[v]] != v
151-
color[v] = forced_colors[v]
155+
if forbidden_colors[forced_colors[v]] == v
156+
throw(InvalidColoringError())
157+
else
158+
color[v] = forced_colors[v]
159+
end
152160
end
153161
_update_stars!(star, hub, g, v, color, first_neighbor)
154162
end
@@ -316,8 +324,11 @@ function acyclic_coloring(
316324
end
317325
end
318326
else
319-
@assert forbidden_colors[forced_colors[v]] != v
320-
color[v] = forced_colors[v]
327+
if forbidden_colors[forced_colors[v]] == v
328+
throw(InvalidColoringError())
329+
else
330+
color[v] = forced_colors[v]
331+
end
321332
end
322333
for (w, index_vw) in neighbors_with_edge_indices(g, v) # grow two-colored stars around the vertex v
323334
!has_diagonal(g) || (v == w && continue)

src/constant.jl

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,24 @@
44
Coloring algorithm which always returns the same precomputed vector of colors.
55
Useful when the optimal coloring of a matrix can be determined a priori due to its specific structure (e.g. banded).
66
7-
It is passed as an argument to the main function [`coloring`](@ref), but will only work if the associated `problem` has `:nonsymmetric` structure.
8-
Indeed, for symmetric coloring problems, we need more than just the vector of colors to allow fast decompression.
7+
It is passed as an argument to the main function [`coloring`](@ref), but will only work if the associated `problem` has a `:column` or `:row` partition.
98
109
# Constructors
1110
1211
ConstantColoringAlgorithm{partition}(matrix_template, color)
13-
ConstantColoringAlgorithm(matrix_template, color; partition=:column)
12+
ConstantColoringAlgorithm{partition,structure}(matrix_template, color)
13+
ConstantColoringAlgorithm(
14+
matrix_template, color;
15+
structure=:nonsymmetric, partition=:column
16+
)
1417
1518
- `partition::Symbol`: either `:row` or `:column`.
19+
- `structure::Symbol`: either `:nonsymmetric` or `:symmetric`.
1620
- `matrix_template::AbstractMatrix`: matrix for which the vector of colors was precomputed (the algorithm will only accept matrices of the exact same size).
1721
- `color::Vector{<:Integer}`: vector of integer colors, one for each row or column (depending on `partition`).
1822
1923
!!! warning
20-
The second constructor (based on keyword arguments) is type-unstable.
24+
The constructor based on keyword arguments is type-unstable if these arguments are not compile-time constants.
2125
2226
We do not necessarily verify consistency between the matrix template and the vector of colors, this is the responsibility of the user.
2327
@@ -63,40 +67,36 @@ julia> column_colors(result)
6367
6468
- [`ADTypes.column_coloring`](@extref ADTypes.column_coloring)
6569
- [`ADTypes.row_coloring`](@extref ADTypes.row_coloring)
70+
- [`ADTypes.symmetric_coloring`](@extref ADTypes.symmetric_coloring)
6671
"""
67-
struct ConstantColoringAlgorithm{
68-
partition,
69-
M<:AbstractMatrix,
70-
T<:Integer,
71-
R<:AbstractColoringResult{:nonsymmetric,partition,:direct},
72-
} <: ADTypes.AbstractColoringAlgorithm
72+
struct ConstantColoringAlgorithm{partition,structure,M<:AbstractMatrix,T<:Integer} <:
73+
ADTypes.AbstractColoringAlgorithm
7374
matrix_template::M
7475
color::Vector{T}
75-
result::R
76-
end
7776

78-
function ConstantColoringAlgorithm{:column}(
79-
matrix_template::AbstractMatrix, color::Vector{<:Integer}
80-
)
81-
bg = BipartiteGraph(matrix_template)
82-
result = ColumnColoringResult(matrix_template, bg, color)
83-
T, M, R = eltype(bg), typeof(matrix_template), typeof(result)
84-
return ConstantColoringAlgorithm{:column,M,T,R}(matrix_template, color, result)
77+
function ConstantColoringAlgorithm{partition,structure}(
78+
matrix_template::AbstractMatrix, color::Vector{<:Integer}
79+
) where {partition,structure}
80+
check_valid_problem(structure, partition)
81+
return new{partition,structure,typeof(matrix_template),eltype(color)}(
82+
matrix_template, color
83+
)
84+
end
8585
end
8686

87-
function ConstantColoringAlgorithm{:row}(
87+
function ConstantColoringAlgorithm{partition}(
8888
matrix_template::AbstractMatrix, color::Vector{<:Integer}
89-
)
90-
bg = BipartiteGraph(matrix_template)
91-
result = RowColoringResult(matrix_template, bg, color)
92-
T, M, R = eltype(bg), typeof(matrix_template), typeof(result)
93-
return ConstantColoringAlgorithm{:row,M,T,R}(matrix_template, color, result)
89+
) where {partition}
90+
return ConstantColoringAlgorithm{partition,:nonsymmetric}(matrix_template, color)
9491
end
9592

9693
function ConstantColoringAlgorithm(
97-
matrix_template::AbstractMatrix, color::Vector{<:Integer}; partition::Symbol=:column
94+
matrix_template::AbstractMatrix,
95+
color::Vector{<:Integer};
96+
structure::Symbol=:nonsymmetric,
97+
partition::Symbol=:column,
9898
)
99-
return ConstantColoringAlgorithm{partition}(matrix_template, color)
99+
return ConstantColoringAlgorithm{partition,structure}(matrix_template, color)
100100
end
101101

102102
function check_template(algo::ConstantColoringAlgorithm, A::AbstractMatrix)
@@ -110,25 +110,25 @@ function check_template(algo::ConstantColoringAlgorithm, A::AbstractMatrix)
110110
end
111111
end
112112

113-
function coloring(
114-
A::AbstractMatrix,
115-
::ColoringProblem{:nonsymmetric,partition},
116-
algo::ConstantColoringAlgorithm{partition};
117-
decompression_eltype::Type=Float64,
118-
symmetric_pattern::Bool=false,
119-
) where {partition}
113+
function ADTypes.column_coloring(
114+
A::AbstractMatrix, algo::ConstantColoringAlgorithm{:column,:nonsymmetric}
115+
)
120116
check_template(algo, A)
121-
return algo.result
117+
return algo.color
122118
end
123119

124-
function ADTypes.column_coloring(
125-
A::AbstractMatrix, algo::ConstantColoringAlgorithm{:column}
120+
function ADTypes.row_coloring(
121+
A::AbstractMatrix, algo::ConstantColoringAlgorithm{:row,:nonsymmetric}
126122
)
127123
check_template(algo, A)
128-
return column_colors(algo.result)
124+
return algo.color
129125
end
130126

131-
function ADTypes.row_coloring(A::AbstractMatrix, algo::ConstantColoringAlgorithm{:row})
127+
function ADTypes.symmetric_coloring(
128+
A::AbstractMatrix, algo::ConstantColoringAlgorithm{:column,:symmetric}
129+
)
132130
check_template(algo, A)
133-
return row_colors(algo.result)
131+
return algo.color
134132
end
133+
134+
# TODO: handle bidirectional once https://github.com/SciML/ADTypes.jl/issues/69 is done

test/constant.jl

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
using ADTypes: ADTypes
22
using SparseMatrixColorings
3+
using SparseMatrixColorings: InvalidColoringError
34
using Test
45

5-
matrix_template = ones(100, 200)
6+
matrix_template = ones(Bool, 10, 20)
7+
sym_matrix_template = ones(Bool, 10, 10)
68

79
@testset "Column coloring" begin
810
problem = ColoringProblem(; structure=:nonsymmetric, partition=:column)
9-
color = rand(1:5, size(matrix_template, 2))
11+
color = collect(1:20)
1012
algo = ConstantColoringAlgorithm(matrix_template, color; partition=:column)
11-
wrong_algo = ConstantColoringAlgorithm(matrix_template, color; partition=:row)
13+
wrong_algo = ConstantColoringAlgorithm{:row}(matrix_template, color)
14+
wrong_color = ConstantColoringAlgorithm{:column}(matrix_template, ones(Int, 20))
1215
@test_throws DimensionMismatch coloring(transpose(matrix_template), problem, algo)
1316
@test_throws MethodError coloring(matrix_template, problem, wrong_algo)
17+
@test_throws InvalidColoringError coloring(matrix_template, problem, wrong_color)
1418
result = coloring(matrix_template, problem, algo)
1519
@test column_colors(result) == color
1620
@test ADTypes.column_coloring(matrix_template, algo) == color
@@ -19,11 +23,36 @@ end
1923

2024
@testset "Row coloring" begin
2125
problem = ColoringProblem(; structure=:nonsymmetric, partition=:row)
22-
color = rand(1:5, size(matrix_template, 1))
26+
color = collect(1:10)
2327
algo = ConstantColoringAlgorithm(matrix_template, color; partition=:row)
28+
wrong_algo = ConstantColoringAlgorithm{:column}(matrix_template, color)
29+
wrong_color = ConstantColoringAlgorithm{:row}(matrix_template, ones(Int, 10))
2430
@test_throws DimensionMismatch coloring(transpose(matrix_template), problem, algo)
31+
@test_throws MethodError coloring(matrix_template, problem, wrong_algo)
32+
@test_throws InvalidColoringError coloring(matrix_template, problem, wrong_color)
2533
result = coloring(matrix_template, problem, algo)
2634
@test row_colors(result) == color
2735
@test ADTypes.row_coloring(matrix_template, algo) == color
2836
@test_throws MethodError ADTypes.column_coloring(matrix_template, algo)
2937
end
38+
39+
@testset "Symmetric coloring" begin
40+
problem = ColoringProblem(; structure=:symmetric, partition=:column)
41+
color = collect(1:10)
42+
algo = ConstantColoringAlgorithm(
43+
sym_matrix_template, color; partition=:column, structure=:symmetric
44+
)
45+
wrong_algo = ConstantColoringAlgorithm{:column,:nonsymmetric}(
46+
sym_matrix_template, color
47+
)
48+
wrong_color = ConstantColoringAlgorithm{:column,:symmetric}(
49+
sym_matrix_template, ones(Int, 20)
50+
)
51+
@test_throws DimensionMismatch coloring(matrix_template, problem, algo)
52+
@test_throws MethodError coloring(sym_matrix_template, problem, wrong_algo)
53+
@test_throws InvalidColoringError coloring(sym_matrix_template, problem, wrong_color)
54+
result = coloring(sym_matrix_template, problem, algo)
55+
@test column_colors(result) == color
56+
@test ADTypes.symmetric_coloring(sym_matrix_template, algo) == color
57+
@test_throws MethodError ADTypes.column_coloring(sym_matrix_template, algo)
58+
end

0 commit comments

Comments
 (0)