@@ -99,35 +99,36 @@ function ConstantColoringAlgorithm(
9999 return ConstantColoringAlgorithm {partition} (matrix_template, color)
100100end
101101
102- function coloring (
103- A:: AbstractMatrix ,
104- :: ColoringProblem{:nonsymmetric,partition} ,
105- algo:: ConstantColoringAlgorithm{partition} ;
106- decompression_eltype:: Type = Float64,
107- symmetric_pattern:: Bool = false ,
108- ) where {partition}
109- (; matrix_template, result) = algo
102+ function check_template (algo:: ConstantColoringAlgorithm , A:: AbstractMatrix )
103+ (; matrix_template) = algo
110104 if size (A) != size (matrix_template)
111105 throw (
112106 DimensionMismatch (
113107 " `ConstantColoringAlgorithm` expected matrix of size $(size (matrix_template)) but got matrix of size $(size (A)) " ,
114108 ),
115109 )
116- else
117- return result
118110 end
119111end
120112
113+ function coloring (
114+ A:: AbstractMatrix ,
115+ :: ColoringProblem{:nonsymmetric,partition} ,
116+ algo:: ConstantColoringAlgorithm{partition} ;
117+ decompression_eltype:: Type = Float64,
118+ symmetric_pattern:: Bool = false ,
119+ ) where {partition}
120+ check_template (algo, A)
121+ return algo. result
122+ end
123+
121124function ADTypes. column_coloring (
122125 A:: AbstractMatrix , algo:: ConstantColoringAlgorithm{:column}
123126)
124- problem = ColoringProblem {:nonsymmetric,:column} ()
125- result = coloring (A, problem, algo)
126- return column_colors (result)
127+ check_template (algo, A)
128+ return column_colors (algo. result)
127129end
128130
129- function ADTypes. row_coloring (A:: AbstractMatrix , algo:: ConstantColoringAlgorithm )
130- problem = ColoringProblem {:nonsymmetric,:row} ()
131- result = coloring (A, problem, algo)
132- return row_colors (result)
131+ function ADTypes. row_coloring (A:: AbstractMatrix , algo:: ConstantColoringAlgorithm{:row} )
132+ check_template (algo, A)
133+ return row_colors (algo. result)
133134end
0 commit comments