Skip to content

Commit 4d01552

Browse files
committed
Fix StackOverflow
1 parent dadb75e commit 4d01552

2 files changed

Lines changed: 18 additions & 25 deletions

File tree

src/constant.jl

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -99,35 +99,36 @@ function ConstantColoringAlgorithm(
9999
return ConstantColoringAlgorithm{partition}(matrix_template, color)
100100
end
101101

102-
function coloring(
103-
A::AbstractMatrix,
104-
::ColoringProblem{:nonsymmetric,partition},
105-
algo::ConstantColoringAlgorithm{partition};
106-
decompression_eltype::Type=Float64,
107-
symmetric_pattern::Bool=false,
108-
) where {partition}
109-
(; matrix_template, result) = algo
102+
function check_template(algo::ConstantColoringAlgorithm, A::AbstractMatrix)
103+
(; matrix_template) = algo
110104
if size(A) != size(matrix_template)
111105
throw(
112106
DimensionMismatch(
113107
"`ConstantColoringAlgorithm` expected matrix of size $(size(matrix_template)) but got matrix of size $(size(A))",
114108
),
115109
)
116-
else
117-
return result
118110
end
119111
end
120112

113+
function coloring(
114+
A::AbstractMatrix,
115+
::ColoringProblem{:nonsymmetric,partition},
116+
algo::ConstantColoringAlgorithm{partition};
117+
decompression_eltype::Type=Float64,
118+
symmetric_pattern::Bool=false,
119+
) where {partition}
120+
check_template(algo, A)
121+
return algo.result
122+
end
123+
121124
function ADTypes.column_coloring(
122125
A::AbstractMatrix, algo::ConstantColoringAlgorithm{:column}
123126
)
124-
problem = ColoringProblem{:nonsymmetric,:column}()
125-
result = coloring(A, problem, algo)
126-
return column_colors(result)
127+
check_template(algo, A)
128+
return column_colors(algo.result)
127129
end
128130

129-
function ADTypes.row_coloring(A::AbstractMatrix, algo::ConstantColoringAlgorithm)
130-
problem = ColoringProblem{:nonsymmetric,:row}()
131-
result = coloring(A, problem, algo)
132-
return row_colors(result)
131+
function ADTypes.row_coloring(A::AbstractMatrix, algo::ConstantColoringAlgorithm{:row})
132+
check_template(algo, A)
133+
return row_colors(algo.result)
133134
end

test/constant.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,3 @@ end
2727
@test ADTypes.row_coloring(matrix_template, algo) == color
2828
@test_throws MethodError ADTypes.column_coloring(matrix_template, algo)
2929
end
30-
31-
@testset "Symmetric coloring" begin
32-
wrong_problem = ColoringProblem(; structure=:symmetric, partition=:column)
33-
color = rand(1:5, size(matrix_template, 2))
34-
algo = ConstantColoringAlgorithm(matrix_template, color; partition=:column)
35-
@test_throws MethodError coloring(matrix_template, wrong_problem, algo)
36-
@test_throws MethodError ADTypes.symmetric_coloring(matrix_template, algo)
37-
end

0 commit comments

Comments
 (0)