Skip to content

Commit 347e297

Browse files
gdalleamontoison
authored andcommitted
Store graph in result to allow generic matrices
1 parent a03a455 commit 347e297

8 files changed

Lines changed: 165 additions & 118 deletions

File tree

src/constant.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,17 @@ end
7575
function ConstantColoringAlgorithm{:column}(
7676
matrix_template::AbstractMatrix, color::Vector{Int}
7777
)
78-
S = convert(SparseMatrixCSC, matrix_template)
79-
result = ColumnColoringResult(S, color)
78+
bg = BipartiteGraph(matrix_template)
79+
result = ColumnColoringResult(matrix_template, bg, color)
8080
M, R = typeof(matrix_template), typeof(result)
8181
return ConstantColoringAlgorithm{:column,M,R}(matrix_template, color, result)
8282
end
8383

8484
function ConstantColoringAlgorithm{:row}(
8585
matrix_template::AbstractMatrix, color::Vector{Int}
8686
)
87-
S = convert(SparseMatrixCSC, matrix_template)
88-
result = RowColoringResult(S, color)
87+
bg = BipartiteGraph(matrix_template)
88+
result = RowColoringResult(matrix_template, bg, color)
8989
M, R = typeof(matrix_template), typeof(result)
9090
return ConstantColoringAlgorithm{:row,M,R}(matrix_template, color, result)
9191
end

src/decompression.jl

Lines changed: 54 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,8 @@ true
115115
- [`ColoringProblem`](@ref)
116116
- [`AbstractColoringResult`](@ref)
117117
"""
118-
function decompress(B::AbstractMatrix{R}, result::AbstractColoringResult) where {R<:Real}
119-
@compat (; S) = result
120-
A = respectful_similar(S, R)
118+
function decompress(B::AbstractMatrix, result::AbstractColoringResult)
119+
A = respectful_similar(result.A, eltype(B))
121120
return decompress!(A, B, result)
122121
end
123122

@@ -264,12 +263,11 @@ end
264263

265264
## ColumnColoringResult
266265

267-
function decompress!(
268-
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::ColumnColoringResult
269-
) where {R<:Real}
270-
@compat (; S, color) = result
266+
function decompress!(A::AbstractMatrix, B::AbstractMatrix, result::ColumnColoringResult)
267+
@compat (; color) = result
268+
S = result.bg.S2
271269
check_same_pattern(A, S)
272-
A .= zero(R)
270+
fill!(A, zero(eltype(A)))
273271
rvS = rowvals(S)
274272
for j in axes(S, 2)
275273
cj = color[j]
@@ -282,9 +280,10 @@ function decompress!(
282280
end
283281

284282
function decompress_single_color!(
285-
A::AbstractMatrix{R}, b::AbstractVector{R}, c::Integer, result::ColumnColoringResult
286-
) where {R<:Real}
287-
@compat (; S, group) = result
283+
A::AbstractMatrix, b::AbstractVector, c::Integer, result::ColumnColoringResult
284+
)
285+
@compat (; group) = result
286+
S = result.bg.S2
288287
check_same_pattern(A, S)
289288
rvS = rowvals(S)
290289
for j in group[c]
@@ -296,10 +295,9 @@ function decompress_single_color!(
296295
return A
297296
end
298297

299-
function decompress!(
300-
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::ColumnColoringResult
301-
) where {R<:Real}
302-
@compat (; S, compressed_indices) = result
298+
function decompress!(A::SparseMatrixCSC, B::AbstractMatrix, result::ColumnColoringResult)
299+
@compat (; compressed_indices) = result
300+
S = result.bg.S2
303301
check_same_pattern(A, S)
304302
nzA = nonzeros(A)
305303
for k in eachindex(nzA, compressed_indices)
@@ -309,9 +307,10 @@ function decompress!(
309307
end
310308

311309
function decompress_single_color!(
312-
A::SparseMatrixCSC{R}, b::AbstractVector{R}, c::Integer, result::ColumnColoringResult
313-
) where {R<:Real}
314-
@compat (; S, group) = result
310+
A::SparseMatrixCSC, b::AbstractVector, c::Integer, result::ColumnColoringResult
311+
)
312+
@compat (; group) = result
313+
S = result.bg.S2
315314
check_same_pattern(A, S)
316315
rvS = rowvals(S)
317316
nzA = nonzeros(A)
@@ -326,12 +325,11 @@ end
326325

327326
## RowColoringResult
328327

329-
function decompress!(
330-
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::RowColoringResult
331-
) where {R<:Real}
332-
@compat (; S, color) = result
328+
function decompress!(A::AbstractMatrix, B::AbstractMatrix, result::RowColoringResult)
329+
@compat (; color) = result
330+
S = result.bg.S2
333331
check_same_pattern(A, S)
334-
A .= zero(R)
332+
fill!(A, zero(eltype(A)))
335333
rvS = rowvals(S)
336334
for j in axes(S, 2)
337335
for k in nzrange(S, j)
@@ -344,9 +342,10 @@ function decompress!(
344342
end
345343

346344
function decompress_single_color!(
347-
A::AbstractMatrix{R}, b::AbstractVector{R}, c::Integer, result::RowColoringResult
348-
) where {R<:Real}
349-
@compat (; S, Sᵀ, group) = result
345+
A::AbstractMatrix, b::AbstractVector, c::Integer, result::RowColoringResult
346+
)
347+
@compat (; group) = result
348+
S, Sᵀ = result.bg.S2, result.bg.S1
350349
check_same_pattern(A, S)
351350
rvSᵀ = rowvals(Sᵀ)
352351
for i in group[c]
@@ -358,10 +357,9 @@ function decompress_single_color!(
358357
return A
359358
end
360359

361-
function decompress!(
362-
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::RowColoringResult
363-
) where {R<:Real}
364-
@compat (; S, compressed_indices) = result
360+
function decompress!(A::SparseMatrixCSC, B::AbstractMatrix, result::RowColoringResult)
361+
@compat (; compressed_indices) = result
362+
S = result.bg.S2
365363
check_same_pattern(A, S)
366364
nzA = nonzeros(A)
367365
for k in eachindex(nzA, compressed_indices)
@@ -373,15 +371,13 @@ end
373371
## StarSetColoringResult
374372

375373
function decompress!(
376-
A::AbstractMatrix{R},
377-
B::AbstractMatrix{R},
378-
result::StarSetColoringResult,
379-
uplo::Symbol=:F,
380-
) where {R<:Real}
381-
@compat (; S, color, star_set) = result
374+
A::AbstractMatrix, B::AbstractMatrix, result::StarSetColoringResult, uplo::Symbol=:F
375+
)
376+
@compat (; color, star_set) = result
382377
@compat (; star, hub, spokes) = star_set
378+
S = result.ag.S
383379
uplo == :F && check_same_pattern(A, S)
384-
A .= zero(R)
380+
A .= zero(eltype(A))
385381
for i in axes(A, 1)
386382
if !iszero(S[i, i])
387383
A[i, i] = B[i, color[i]]
@@ -403,14 +399,15 @@ function decompress!(
403399
end
404400

405401
function decompress_single_color!(
406-
A::AbstractMatrix{R},
407-
b::AbstractVector{R},
402+
A::AbstractMatrix,
403+
b::AbstractVector,
408404
c::Integer,
409405
result::StarSetColoringResult,
410406
uplo::Symbol=:F,
411-
) where {R<:Real}
412-
@compat (; S, color, group, star_set) = result
407+
)
408+
@compat (; color, group, star_set) = result
413409
@compat (; hub, spokes) = star_set
410+
S = result.ag.S
414411
uplo == :F && check_same_pattern(A, S)
415412
for i in axes(A, 1)
416413
if !iszero(S[i, i]) && color[i] == c
@@ -434,12 +431,10 @@ function decompress_single_color!(
434431
end
435432

436433
function decompress!(
437-
A::SparseMatrixCSC{R},
438-
B::AbstractMatrix{R},
439-
result::StarSetColoringResult,
440-
uplo::Symbol=:F,
441-
) where {R<:Real}
442-
@compat (; S, compressed_indices) = result
434+
A::SparseMatrixCSC, B::AbstractMatrix, result::StarSetColoringResult, uplo::Symbol=:F
435+
)
436+
@compat (; compressed_indices) = result
437+
S = result.ag.S
443438
nzA = nonzeros(A)
444439
if uplo == :F
445440
check_same_pattern(A, S)
@@ -468,13 +463,12 @@ end
468463
# TODO: add method for A::SparseMatrixCSC
469464

470465
function decompress!(
471-
A::AbstractMatrix{R},
472-
B::AbstractMatrix{R},
473-
result::TreeSetColoringResult,
474-
uplo::Symbol=:F,
475-
) where {R<:Real}
476-
@compat (; S, color, vertices_by_tree, reverse_bfs_orders, buffer) = result
466+
A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult, uplo::Symbol=:F
467+
)
468+
@compat (; color, vertices_by_tree, reverse_bfs_orders, buffer) = result
469+
S = result.ag.S
477470
uplo == :F && check_same_pattern(A, S)
471+
R = eltype(A)
478472
A .= zero(R)
479473

480474
if eltype(buffer) == R
@@ -513,19 +507,19 @@ end
513507
## MatrixInverseColoringResult
514508

515509
function decompress!(
516-
A::AbstractMatrix{R},
517-
B::AbstractMatrix{R},
510+
A::AbstractMatrix,
511+
B::AbstractMatrix,
518512
result::LinearSystemColoringResult,
519513
uplo::Symbol=:F,
520-
) where {R<:Real}
521-
@compat (;
522-
S, color, strict_upper_nonzero_inds, T_factorization, strict_upper_nonzeros_A
523-
) = result
514+
)
515+
@compat (; color, strict_upper_nonzero_inds, T_factorization, strict_upper_nonzeros_A) =
516+
result
517+
S = result.ag.S
524518
uplo == :F && check_same_pattern(A, S)
525519

526520
# TODO: for some reason I cannot use ldiv! with a sparse QR
527521
strict_upper_nonzeros_A = T_factorization \ vec(B)
528-
A .= zero(R)
522+
A .= zero(eltype(A))
529523
for i in axes(A, 1)
530524
if !iszero(S[i, i])
531525
A[i, i] = B[i, color[i]]

src/graph.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,19 @@ end
2424
SparsityPatternCSC(A::SparseMatrixCSC) = SparsityPatternCSC(A.m, A.n, A.colptr, A.rowval)
2525

2626
Base.size(S::SparsityPatternCSC) = (S.m, S.n)
27+
28+
function Base.size(S::SparsityPatternCSC, d::Integer)
29+
if d == 1
30+
return S.m
31+
elseif d == 2
32+
return S.n
33+
else
34+
return 1
35+
end
36+
end
37+
38+
Base.axes(S::SparsityPatternCSC, d::Integer) = Base.OneTo(size(S, d))
39+
2740
SparseArrays.nnz(S::SparsityPatternCSC) = length(S.rowval)
2841
SparseArrays.rowvals(S::SparsityPatternCSC) = S.rowval
2942
SparseArrays.nzrange(S::SparsityPatternCSC, j::Integer) = S.colptr[j]:(S.colptr[j + 1] - 1)
@@ -81,6 +94,15 @@ function Base.transpose(S::SparsityPatternCSC{T}) where {T}
8194
return SparsityPatternCSC{T}(n, m, B_colptr, B_rowval)
8295
end
8396

97+
# copied from SparseArrays.jl
98+
function Base.getindex(S::SparsityPatternCSC, i0::Integer, i1::Integer)
99+
r1 = Int(S.colptr[i1])
100+
r2 = Int(S.colptr[i1 + 1] - 1)
101+
(r1 > r2) && return false
102+
r1 = searchsortedfirst(rowvals(S), i0, r1, r2, Base.Order.Forward)
103+
return ((r1 > r2) || (rowvals(S)[r1] != i0)) ? false : true
104+
end
105+
84106
## Adjacency graph
85107

86108
"""
@@ -109,6 +131,7 @@ struct AdjacencyGraph{T}
109131
S::SparsityPatternCSC{T}
110132
end
111133

134+
AdjacencyGraph(A::AbstractMatrix) = AdjacencyGraph(SparseMatrixCSC(A))
112135
AdjacencyGraph(A::SparseMatrixCSC) = AdjacencyGraph(SparsityPatternCSC(A))
113136

114137
pattern(g::AdjacencyGraph) = g.S
@@ -183,6 +206,10 @@ struct BipartiteGraph{T<:Integer}
183206
S2::SparsityPatternCSC{T}
184207
end
185208

209+
function BipartiteGraph(A::AbstractMatrix; symmetric_pattern::Bool=false)
210+
return BipartiteGraph(SparseMatrixCSC(A); symmetric_pattern)
211+
end
212+
186213
function BipartiteGraph(A::SparseMatrixCSC; symmetric_pattern::Bool=false)
187214
S2 = SparsityPatternCSC(A) # columns to rows
188215
if symmetric_pattern

src/interface.jl

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,11 @@ function coloring(
180180
decompression_eltype::Type=Float64,
181181
symmetric_pattern::Bool=false,
182182
)
183-
S = convert(SparseMatrixCSC, A)
184183
bg = BipartiteGraph(
185-
S; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
184+
A; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
186185
)
187186
color = partial_distance2_coloring(bg, Val(2), algo.order)
188-
return ColumnColoringResult(S, color)
187+
return ColumnColoringResult(A, bg, color)
189188
end
190189

191190
function coloring(
@@ -195,12 +194,11 @@ function coloring(
195194
decompression_eltype::Type=Float64,
196195
symmetric_pattern::Bool=false,
197196
)
198-
S = convert(SparseMatrixCSC, A)
199197
bg = BipartiteGraph(
200-
S; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
198+
A; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
201199
)
202200
color = partial_distance2_coloring(bg, Val(1), algo.order)
203-
return RowColoringResult(S, color)
201+
return RowColoringResult(A, bg, color)
204202
end
205203

206204
function coloring(
@@ -209,10 +207,9 @@ function coloring(
209207
algo::GreedyColoringAlgorithm{:direct};
210208
decompression_eltype::Type=Float64,
211209
)
212-
S = convert(SparseMatrixCSC, A)
213-
ag = AdjacencyGraph(S)
210+
ag = AdjacencyGraph(A)
214211
color, star_set = star_coloring(ag, algo.order)
215-
return StarSetColoringResult(S, color, star_set)
212+
return StarSetColoringResult(A, ag, color, star_set)
216213
end
217214

218215
function coloring(
@@ -221,31 +218,27 @@ function coloring(
221218
algo::GreedyColoringAlgorithm{:substitution};
222219
decompression_eltype::Type=Float64,
223220
)
224-
S = convert(SparseMatrixCSC, A)
225-
ag = AdjacencyGraph(S)
221+
ag = AdjacencyGraph(A)
226222
color, tree_set = acyclic_coloring(ag, algo.order)
227-
return TreeSetColoringResult(S, color, tree_set, decompression_eltype)
223+
return TreeSetColoringResult(A, ag, color, tree_set, decompression_eltype)
228224
end
229225

230226
## ADTypes interface
231227

232228
function ADTypes.column_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
233-
S = convert(SparseMatrixCSC, A)
234-
bg = BipartiteGraph(S; symmetric_pattern=A isa Union{Symmetric,Hermitian})
229+
bg = BipartiteGraph(A; symmetric_pattern=A isa Union{Symmetric,Hermitian})
235230
color = partial_distance2_coloring(bg, Val(2), algo.order)
236231
return color
237232
end
238233

239234
function ADTypes.row_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
240-
S = convert(SparseMatrixCSC, A)
241-
bg = BipartiteGraph(S; symmetric_pattern=A isa Union{Symmetric,Hermitian})
235+
bg = BipartiteGraph(A; symmetric_pattern=A isa Union{Symmetric,Hermitian})
242236
color = partial_distance2_coloring(bg, Val(1), algo.order)
243237
return color
244238
end
245239

246240
function ADTypes.symmetric_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
247-
S = convert(SparseMatrixCSC, A)
248-
ag = AdjacencyGraph(S)
241+
ag = AdjacencyGraph(A)
249242
color, star_set = star_coloring(ag, algo.order)
250243
return color
251244
end

0 commit comments

Comments
 (0)