Skip to content

Commit 92594c4

Browse files
committed
Store CSR indices as additional info
1 parent f7daa2f commit 92594c4

4 files changed

Lines changed: 147 additions & 32 deletions

File tree

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,26 @@ function SMC.compress(
2222
return CuMatrix(SMC.compress(SparseMatrixCSC(A), result))
2323
end
2424

25-
## CSC
25+
## CSC Result
2626

2727
function SMC.ColumnColoringResult(
2828
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
2929
) where {T<:Integer}
30-
A_cpu = SparseMatrixCSC(A)
31-
result_cpu = SMC.ColumnColoringResult(A_cpu, bg, color)
32-
compressed_indices = CuVector(result_cpu.compressed_indices)
33-
return SMC.ColumnColoringResult(A, bg, color, result_cpu.group, compressed_indices)
30+
group = SMC.group_by_color(T, color)
31+
compressed_indices = SMC.column_csc_indices(bg, color)
32+
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
33+
return SMC.ColumnColoringResult(
34+
A, bg, color, group, compressed_indices, additional_info
35+
)
3436
end
3537

3638
function SMC.RowColoringResult(
3739
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
3840
) where {T<:Integer}
39-
A_cpu = SparseMatrixCSC(A)
40-
result_cpu = SMC.RowColoringResult(A_cpu, bg, color)
41-
compressed_indices = CuVector(result_cpu.compressed_indices)
42-
return SMC.RowColoringResult(A, bg, color, result_cpu.group, compressed_indices)
41+
group = SMC.group_by_color(T, color)
42+
compressed_indices = SMC.row_csc_indices(bg, color)
43+
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
44+
return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
4345
end
4446

4547
function SMC.StarSetColoringResult(
@@ -48,12 +50,54 @@ function SMC.StarSetColoringResult(
4850
color::Vector{<:Integer},
4951
star_set::SMC.StarSet{<:Integer},
5052
) where {T<:Integer}
51-
A_cpu = SparseMatrixCSC(A)
52-
result_cpu = SMC.StarSetColoringResult(A_cpu, ag, color, star_set)
53-
compressed_indices = CuVector(result_cpu.compressed_indices)
54-
return SMC.StarSetColoringResult(A, ag, color, result_cpu.group, compressed_indices)
53+
group = SMC.group_by_color(T, color)
54+
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
55+
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
56+
return SMC.StarSetColoringResult(
57+
A, ag, color, group, compressed_indices, additional_info
58+
)
59+
end
60+
61+
## CSR Result
62+
63+
function SMC.ColumnColoringResult(
64+
A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
65+
) where {T<:Integer}
66+
group = SMC.group_by_color(T, color)
67+
compressed_indices = SMC.column_csc_indices(bg, color)
68+
compressed_indices_csr = SMC.column_csr_indices(bg, color)
69+
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices_csr))
70+
return SMC.ColumnColoringResult(
71+
A, bg, color, group, compressed_indices, additional_info
72+
)
73+
end
74+
75+
function SMC.RowColoringResult(
76+
A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
77+
) where {T<:Integer}
78+
group = SMC.group_by_color(T, color)
79+
compressed_indices = SMC.row_csc_indices(bg, color)
80+
compressed_indices_csr = SMC.row_csr_indices(bg, color)
81+
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices_csr))
82+
return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
83+
end
84+
85+
function SMC.StarSetColoringResult(
86+
A::CuSparseMatrixCSR,
87+
ag::SMC.AdjacencyGraph{T},
88+
color::Vector{<:Integer},
89+
star_set::SMC.StarSet{<:Integer},
90+
) where {T<:Integer}
91+
group = SMC.group_by_color(T, color)
92+
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
93+
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices))
94+
return SMC.StarSetColoringResult(
95+
A, ag, color, group, compressed_indices, additional_info
96+
)
5597
end
5698

99+
## Decompression
100+
57101
function update_nzval_from_matrix!(
58102
nzVal::AbstractVector, B::AbstractMatrix, compressed_indices::AbstractVector{<:Integer}
59103
)
@@ -70,14 +114,30 @@ for R in (:ColumnColoringResult, :RowColoringResult, :StarSetColoringResult)
70114
@eval function SMC.decompress!(
71115
A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.$R{<:CuSparseMatrixCSC}
72116
)
117+
compressed_indices = result.additional_info.compressed_indices_gpu_csc
118+
A.nnz == 0 && return A
119+
kernel = @cuda launch = false update_nzval_from_matrix!(
120+
A.nzVal, B, compressed_indices
121+
)
122+
config = launch_configuration(kernel.fun)
123+
threads = min(A.nnz, config.threads)
124+
blocks = cld(A.nnz, threads)
125+
kernel(A.nzVal, B, compressed_indices; threads, blocks)
126+
return A
127+
end
128+
129+
@eval function SMC.decompress!(
130+
A::CuSparseMatrixCSR, B::CuMatrix, result::SMC.$R{<:CuSparseMatrixCSR}
131+
)
132+
compressed_indices = result.additional_info.compressed_indices_gpu_csr
73133
A.nnz == 0 && return A
74134
kernel = @cuda launch = false update_nzval_from_matrix!(
75-
A.nzVal, B, result.compressed_indices
135+
A.nzVal, B, compressed_indices
76136
)
77137
config = launch_configuration(kernel.fun)
78138
threads = min(A.nnz, config.threads)
79139
blocks = cld(A.nnz, threads)
80-
kernel(A.nzVal, B, result.compressed_indices; threads, blocks)
140+
kernel(A.nzVal, B, compressed_indices; threads, blocks)
81141
return A
82142
end
83143
end

src/graph.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,6 @@ end
345345

346346
Base.eltype(::BipartiteGraph{T}) where {T} = T
347347

348-
Base.transpose(bg::BipartiteGraph) = BipartiteGraph(bg.S2, bg.S1)
349-
350348
function BipartiteGraph(A::AbstractMatrix; symmetric_pattern::Bool=false)
351349
return BipartiteGraph(SparseMatrixCSC(A); symmetric_pattern)
352350
end

src/result.jl

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ struct ColumnColoringResult{
152152
CT<:AbstractVector{T},
153153
GT<:AbstractGroups{T},
154154
VT<:AbstractVector{T},
155+
A,
155156
} <: AbstractColoringResult{:nonsymmetric,:column,:direct}
156157
"matrix that was colored"
157158
A::M
@@ -161,15 +162,22 @@ struct ColumnColoringResult{
161162
color::CT
162163
"color groups for columns or rows (depending on `partition`)"
163164
group::GT
164-
"flattened indices mapping the compressed matrix `B` to the uncompressed matrix `A`. When `A isa SparseMatrixCSC`, they satisfy `nonzeros(A)[k] = vec(B)[compressed_indices[k]]`."
165+
"flattened indices mapping the compressed matrix `B` to the uncompressed matrix `A` when `A isa SparseMatrixCSC`. They satisfy `nonzeros(A)[k] = vec(B)[compressed_indices[k]]`"
165166
compressed_indices::VT
167+
"optional data used for decompressing into specific matrix types"
168+
additional_info::A
166169
end
167170

168171
function ColumnColoringResult(
169172
A::AbstractMatrix, bg::BipartiteGraph{T}, color::Vector{<:Integer}
170173
) where {T<:Integer}
171-
S = bg.S2
172174
group = group_by_color(T, color)
175+
compressed_indices = column_csc_indices(bg, color)
176+
return ColumnColoringResult(A, bg, color, group, compressed_indices, nothing)
177+
end
178+
179+
function column_csc_indices(bg::BipartiteGraph{T}, color::Vector{<:Integer}) where {T}
180+
S = bg.S2
173181
n = size(S, 1)
174182
rv = rowvals(S)
175183
compressed_indices = zeros(T, nnz(S))
@@ -181,7 +189,23 @@ function ColumnColoringResult(
181189
compressed_indices[k] = (c - 1) * n + i
182190
end
183191
end
184-
return ColumnColoringResult(A, bg, color, group, compressed_indices)
192+
return compressed_indices
193+
end
194+
195+
function column_csr_indices(bg::BipartiteGraph{T}, color::Vector{<:Integer}) where {T}
196+
Sᵀ = bg.S1 # CSC storage of transpose(A)
197+
n = size(Sᵀ, 2)
198+
rv = rowvals(Sᵀ)
199+
compressed_indices = zeros(T, nnz(Sᵀ))
200+
for i in axes(Sᵀ, 2)
201+
for k in nzrange(Sᵀ, i)
202+
j = rv[k]
203+
c = color[j]
204+
# A[i, j] = B[i, c]
205+
compressed_indices[k] = (c - 1) * n + i
206+
end
207+
end
208+
return compressed_indices
185209
end
186210

187211
"""
@@ -206,20 +230,27 @@ struct RowColoringResult{
206230
CT<:AbstractVector{T},
207231
GT<:AbstractGroups{T},
208232
VT<:AbstractVector{T},
233+
A,
209234
} <: AbstractColoringResult{:nonsymmetric,:row,:direct}
210235
A::M
211236
bg::G
212237
color::CT
213238
group::GT
214239
compressed_indices::VT
240+
additional_info::A
215241
end
216242

217243
function RowColoringResult(
218244
A::AbstractMatrix, bg::BipartiteGraph{T}, color::Vector{<:Integer}
219245
) where {T<:Integer}
220-
S = bg.S2
221246
group = group_by_color(T, color)
222-
C = length(group) # ncolors
247+
compressed_indices = row_csc_indices(bg, color)
248+
return RowColoringResult(A, bg, color, group, compressed_indices, nothing)
249+
end
250+
251+
function row_csc_indices(bg::BipartiteGraph{T}, color::Vector{<:Integer}) where {T}
252+
S = bg.S2
253+
C = maximum(color) # ncolors
223254
rv = rowvals(S)
224255
compressed_indices = zeros(T, nnz(S))
225256
for j in axes(S, 2)
@@ -230,7 +261,23 @@ function RowColoringResult(
230261
compressed_indices[k] = (j - 1) * C + c
231262
end
232263
end
233-
return RowColoringResult(A, bg, color, group, compressed_indices)
264+
return compressed_indices
265+
end
266+
267+
function row_csr_indices(bg::BipartiteGraph{T}, color::Vector{<:Integer}) where {T}
268+
Sᵀ = bg.S1 # CSC storage of transpose(A)
269+
C = maximum(color) # ncolors
270+
rv = rowvals(Sᵀ)
271+
compressed_indices = zeros(T, nnz(Sᵀ))
272+
for i in axes(Sᵀ, 2)
273+
for k in nzrange(Sᵀ, i)
274+
j = rv[k]
275+
c = color[i]
276+
# A[i, j] = B[c, j]
277+
compressed_indices[k] = (j - 1) * C + c
278+
end
279+
end
280+
return compressed_indices
234281
end
235282

236283
"""
@@ -255,12 +302,14 @@ struct StarSetColoringResult{
255302
CT<:AbstractVector{T},
256303
GT<:AbstractGroups{T},
257304
VT<:AbstractVector{T},
305+
A,
258306
} <: AbstractColoringResult{:symmetric,:column,:direct}
259307
A::M
260308
ag::G
261309
color::CT
262310
group::GT
263311
compressed_indices::VT
312+
additional_info::A
264313
end
265314

266315
function StarSetColoringResult(
@@ -269,11 +318,18 @@ function StarSetColoringResult(
269318
color::Vector{<:Integer},
270319
star_set::StarSet{<:Integer},
271320
) where {T<:Integer}
321+
group = group_by_color(T, color)
322+
compressed_indices = star_csc_indices(ag, color, star_set)
323+
return StarSetColoringResult(A, ag, color, group, compressed_indices, nothing)
324+
end
325+
326+
function star_csc_indices(
327+
ag::AdjacencyGraph{T}, color::Vector{<:Integer}, star_set
328+
) where {T}
272329
(; star, hub) = star_set
273330
S = pattern(ag)
274331
edge_to_index = edge_indices(ag)
275332
n = S.n
276-
group = group_by_color(T, color)
277333
rvS = rowvals(S)
278334
compressed_indices = zeros(T, nnz(S)) # needs to be independent from the storage in the graph, in case the graph gets reused
279335
for j in axes(S, 2)
@@ -302,8 +358,7 @@ function StarSetColoringResult(
302358
end
303359
end
304360
end
305-
306-
return StarSetColoringResult(A, ag, color, group, compressed_indices)
361+
return compressed_indices
307362
end
308363

309364
"""

test/cuda.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using SparseMatrixColorings
55
using StableRNGs
66
using Test
77

8+
include("utils.jl")
9+
810
rng = StableRNG(63)
911

1012
asymmetric_params = vcat(
@@ -19,32 +21,32 @@ symmetric_params = vcat(
1921
[(100, p) for p in (0.01:0.02:0.05)],
2022
)
2123

22-
@testset "Column coloring & decompression" begin
24+
@testset verbose = true "Column coloring & decompression" begin
2325
problem = ColoringProblem(; structure=:nonsymmetric, partition=:column)
2426
algo = GreedyColoringAlgorithm(; decompression=:direct)
25-
@testset for T in (CuSparseMatrixCSC,)
27+
@testset for T in (CuSparseMatrixCSC, CuSparseMatrixCSR)
2628
@testset "$((; m, n, p))" for (m, n, p) in asymmetric_params
2729
A0 = T(sprand(rng, m, n, p))
2830
test_coloring_decompression(A0, problem, algo; gpu=true)
2931
end
3032
end
3133
end;
3234

33-
@testset "Row coloring & decompression" begin
35+
@testset verbose = true "Row coloring & decompression" begin
3436
problem = ColoringProblem(; structure=:nonsymmetric, partition=:row)
3537
algo = GreedyColoringAlgorithm(; decompression=:direct)
36-
@testset for T in (CuSparseMatrixCSC,)
38+
@testset for T in (CuSparseMatrixCSC, CuSparseMatrixCSR)
3739
@testset "$((; m, n, p))" for (m, n, p) in asymmetric_params
3840
A0 = T(sprand(rng, m, n, p))
3941
test_coloring_decompression(A0, problem, algo; gpu=true)
4042
end
4143
end
4244
end;
4345

44-
@testset "Symmetric coloring & direct decompression" begin
46+
@testset verbose = true "Symmetric coloring & direct decompression" begin
4547
problem = ColoringProblem(; structure=:symmetric, partition=:column)
4648
algo = GreedyColoringAlgorithm(; postprocessing=false, decompression=:direct)
47-
@testset for T in (CuSparseMatrixCSC,)
49+
@testset for T in (CuSparseMatrixCSC, CuSparseMatrixCSR)
4850
@testset "$((; n, p))" for (n, p) in symmetric_params
4951
A0 = T(sparse(Symmetric(sprand(rng, n, n, p))))
5052
test_coloring_decompression(A0, problem, algo; gpu=true)

0 commit comments

Comments
 (0)