Skip to content

Commit 1ce9fb8

Browse files
amontoisongdalle
andauthored
Improve how we build A_and_Aᵀ (#156)
* Improve how we build A_and_Aᵀ * Update src/interface.jl * Add tests * Doc --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 50f3f23 commit 1ce9fb8

4 files changed

Lines changed: 96 additions & 13 deletions

File tree

docs/src/dev.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ SparseMatrixColorings.BipartiteGraph
1616
SparseMatrixColorings.vertices
1717
SparseMatrixColorings.neighbors
1818
transpose
19+
SparseMatrixColorings.bidirectional_pattern
1920
```
2021

2122
## Low-level coloring

src/graph.jl

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Copied from `SparseMatrixCSC`:
1414
- `colptr::Vector{Ti}`: column `j` is in `colptr[j]:(colptr[j+1]-1)`
1515
- `rowval::Vector{Ti}`: row indices of stored values
1616
"""
17-
struct SparsityPatternCSC{Ti<:Integer}
17+
struct SparsityPatternCSC{Ti<:Integer} <: AbstractMatrix{Bool}
1818
m::Int
1919
n::Int
2020
colptr::Vector{Ti}
@@ -93,6 +93,76 @@ function Base.getindex(S::SparsityPatternCSC, i0::Integer, i1::Integer)
9393
return ((r1 > r2) || (rowvals(S)[r1] != i0)) ? false : true
9494
end
9595

96+
"""
97+
bidirectional_pattern(A::AbstractMatrix; symmetric_pattern::Bool)
98+
99+
Return a [`SparsityPatternCSC`](@ref) corresponding to the matrix `[0 Aᵀ; A 0]`, with a minimum of allocations.
100+
"""
101+
bidirectional_pattern(A::AbstractMatrix; symmetric_pattern) =
102+
bidirectional_pattern(SparsityPatternCSC(SparseMatrixCSC(A)); symmetric_pattern)
103+
104+
function bidirectional_pattern(S::SparsityPatternCSC; symmetric_pattern)
105+
m, n = size(S)
106+
p = m + n
107+
nnzS = nnz(S)
108+
rowval = Vector{Int}(undef, 2 * nnzS)
109+
colptr = zeros(Int, p + 1)
110+
111+
# Update rowval and colptr for the block A
112+
for i in 1:nnzS
113+
rowval[i] = S.rowval[i] + n
114+
end
115+
for j in 1:n
116+
colptr[j] = S.colptr[j]
117+
end
118+
119+
# Update rowval and colptr for the block Aᵀ
120+
if symmetric_pattern
121+
# We use the sparsity pattern of A for Aᵀ
122+
for i in 1:nnzS
123+
rowval[nnzS + i] = S.rowval[i]
124+
end
125+
# m and n are identical because symmetric_pattern is true
126+
for j in 1:m
127+
colptr[n + j] = nnzS + S.colptr[j]
128+
end
129+
colptr[p + 1] = 2 * nnzS + 1
130+
else
131+
# We need to determine the sparsity pattern of Aᵀ
132+
# We adapt the code of transpose(SparsityPatternCSC) in graph.jl
133+
for k in 1:nnzS
134+
i = S.rowval[k]
135+
colptr[n + i] += 1
136+
end
137+
138+
counter = 1
139+
for col in (n + 1):p
140+
nnz_col = colptr[col]
141+
colptr[col] = counter
142+
counter += nnz_col
143+
end
144+
145+
for j in 1:n
146+
for index in S.colptr[j]:(S.colptr[j + 1] - 1)
147+
i = S.rowval[index]
148+
pos = colptr[n + i]
149+
rowval[nnzS + pos] = j
150+
colptr[n + i] += 1
151+
end
152+
end
153+
154+
colptr[p + 1] = nnzS + counter
155+
for col in p:-1:(n + 2)
156+
colptr[col] = nnzS + colptr[col - 1]
157+
end
158+
colptr[n + 1] = nnzS + 1
159+
end
160+
161+
# Create the SparsityPatternCSC of the augmented adjacency matrix
162+
S_and_Sᵀ = SparsityPatternCSC{Int}(p, p, colptr, rowval)
163+
return S_and_Sᵀ
164+
end
165+
96166
## Adjacency graph
97167

98168
"""

src/interface.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -234,18 +234,9 @@ function coloring(
234234
decompression_eltype::Type{R}=Float64,
235235
symmetric_pattern::Bool=false,
236236
) where {decompression,R}
237-
m, n = size(A)
238-
T = eltype(A)
239-
Aᵀ = if symmetric_pattern || A isa Union{Symmetric,Hermitian}
240-
A
241-
else
242-
transpose(A)
243-
end # TODO: fuse with next step?
244-
A_and_Aᵀ = [
245-
spzeros(T, n, n) SparseMatrixCSC(Aᵀ)
246-
SparseMatrixCSC(A) spzeros(T, m, m)
247-
] # TODO: slow
237+
A_and_Aᵀ = bidirectional_pattern(A; symmetric_pattern)
248238
ag = AdjacencyGraph(A_and_Aᵀ; has_diagonal=false)
239+
249240
if decompression == :direct
250241
color, star_set = star_coloring(ag, algo.order; postprocessing=algo.postprocessing)
251242
symmetric_result = StarSetColoringResult(A_and_Aᵀ, ag, color, star_set)

test/graph.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using SparseMatrixColorings:
44
SparsityPatternCSC,
55
AdjacencyGraph,
66
BipartiteGraph,
7+
bidirectional_pattern,
78
degree,
89
degree_dist2,
910
nb_vertices,
@@ -16,13 +17,33 @@ using Test
1617
@testset "SparsityPatternCSC" begin
1718
@testset "Transpose" begin
1819
@test all(1:1000) do _
19-
A = sprand(rand(100:1000), rand(100:1000), 0.1)
20+
m, n = rand(100:1000), rand(100:1000)
21+
p = 0.05 * rand()
22+
A = sprand(m, n, p)
2023
S = SparsityPatternCSC(A)
2124
Sᵀ = transpose(S)
2225
Sᵀ_true = SparsityPatternCSC(sparse(transpose(A)))
2326
Sᵀ.colptr == Sᵀ_true.colptr && Sᵀ.rowval == Sᵀ_true.rowval
2427
end
2528
end
29+
@testset "Bidirectional" begin
30+
@test all(1:1000) do _
31+
m, n = rand(100:1000), rand(100:1000)
32+
p = 0.05 * rand()
33+
A = sprand(Bool, m, n, p)
34+
A_and_Aᵀ = [spzeros(Bool, n, n) transpose(A); A spzeros(Bool, m, m)]
35+
S_and_Sᵀ = bidirectional_pattern(A; symmetric_pattern=false)
36+
S_and_Sᵀ.colptr == A_and_Aᵀ.colptr && S_and_Sᵀ.rowval == A_and_Aᵀ.rowval
37+
end
38+
@test all(1:1000) do _
39+
m = rand(100:1000)
40+
p = 0.05 * rand()
41+
A = sparse(Symmetric(sprand(Bool, m, m, p)))
42+
A_and_Aᵀ = [spzeros(Bool, m, m) transpose(A); A spzeros(Bool, m, m)]
43+
S_and_Sᵀ = bidirectional_pattern(A; symmetric_pattern=true)
44+
S_and_Sᵀ.colptr == A_and_Aᵀ.colptr && S_and_Sᵀ.rowval == A_and_Aᵀ.rowval
45+
end
46+
end
2647
@testset "size" begin
2748
A = spzeros(10, 20)
2849
S = SparsityPatternCSC(A)

0 commit comments

Comments
 (0)