Skip to content

Commit 77002f4

Browse files
committed
Decompression for CuSparseMatrixCSC
1 parent f7acc3f commit 77002f4

7 files changed

Lines changed: 183 additions & 25 deletions

File tree

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseMatrixColorings"
22
uuid = "0a514795-09f3-496d-8182-132a7b665d35"
33
authors = ["Guillaume Dalle", "Alexis Montoison"]
4-
version = "0.4.20"
4+
version = "0.4.21"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -12,15 +12,18 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1212
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1313

1414
[weakdeps]
15+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1516
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
1617
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
1718

1819
[extensions]
20+
SparseMatrixColoringsCUDAExt = "CUDA"
1921
SparseMatrixColoringsCliqueTreesExt = "CliqueTrees"
2022
SparseMatrixColoringsColorsExt = "Colors"
2123

2224
[compat]
2325
ADTypes = "1.2.1"
26+
CUDA = "5.8.2"
2427
CliqueTrees = "1"
2528
Colors = "0.12.11, 0.13"
2629
DocStringExtensions = "0.8,0.9"
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
module SparseMatrixColoringsCUDAExt
2+
3+
import SparseMatrixColorings as SMC
4+
using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
5+
using CUDA: CuVector, CuMatrix
6+
using CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR
7+
8+
SMC.matrix_versions(A::AbstractCuSparseMatrix) = (A,)
9+
10+
## Compression (slow, through CPU)
11+
12+
function SMC.compress(
13+
A::AbstractCuSparseMatrix, result::SMC.AbstractColoringResult{structure,:column}
14+
) where {structure}
15+
return CuMatrix(SMC.compress(SparseMatrixCSC(A), result))
16+
end
17+
18+
function SMC.compress(
19+
A::AbstractCuSparseMatrix, result::SMC.AbstractColoringResult{structure,:row}
20+
) where {structure}
21+
return CuMatrix(SMC.compress(SparseMatrixCSC(A), result))
22+
end
23+
24+
## CSC
25+
26+
function SMC.ColumnColoringResult(
27+
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
28+
) where {T<:Integer}
29+
A_cpu = SparseMatrixCSC(A)
30+
result_cpu = SMC.ColumnColoringResult(A_cpu, bg, color)
31+
compressed_indices = CuVector(result_cpu.compressed_indices)
32+
return SMC.ColumnColoringResult(A, bg, color, result_cpu.group, compressed_indices)
33+
end
34+
35+
function SMC.RowColoringResult(
36+
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
37+
) where {T<:Integer}
38+
A_cpu = SparseMatrixCSC(A)
39+
result_cpu = SMC.RowColoringResult(A_cpu, bg, color)
40+
compressed_indices = CuVector(result_cpu.compressed_indices)
41+
return SMC.RowColoringResult(A, bg, color, result_cpu.group, compressed_indices)
42+
end
43+
44+
function SMC.StarSetColoringResult(
45+
A::CuSparseMatrixCSC,
46+
ag::SMC.AdjacencyGraph{T},
47+
color::Vector{<:Integer},
48+
star_set::SMC.StarSet{<:Integer},
49+
) where {T<:Integer}
50+
A_cpu = SparseMatrixCSC(A)
51+
result_cpu = SMC.StarSetColoringResult(A_cpu, ag, color, star_set)
52+
compressed_indices = CuVector(result_cpu.compressed_indices)
53+
return SMC.StarSetColoringResult(A, ag, color, result_cpu.group, compressed_indices)
54+
end
55+
56+
function SMC.decompress!(
57+
A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.ColumnColoringResult{<:CuSparseMatrixCSC}
58+
)
59+
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
60+
return A
61+
end
62+
63+
function SMC.decompress!(
64+
A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.RowColoringResult{<:CuSparseMatrixCSC}
65+
)
66+
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
67+
return A
68+
end
69+
70+
function SMC.decompress!(
71+
A::CuSparseMatrixCSC,
72+
B::CuMatrix,
73+
result::SMC.StarSetColoringResult{<:CuSparseMatrixCSC},
74+
)
75+
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
76+
return A
77+
end
78+
79+
end

src/graph.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ end
100100
Return a [`SparsityPatternCSC`](@ref) corresponding to the matrix `[0 Aᵀ; A 0]`, with a minimum of allocations.
101101
"""
102102
function bidirectional_pattern(A::AbstractMatrix; symmetric_pattern::Bool)
103-
bidirectional_pattern(SparsityPatternCSC(SparseMatrixCSC(A)); symmetric_pattern)
103+
return bidirectional_pattern(SparsityPatternCSC(SparseMatrixCSC(A)); symmetric_pattern)
104104
end
105105

106106
function bidirectional_pattern(S::SparsityPatternCSC{T}; symmetric_pattern::Bool) where {T}
@@ -345,6 +345,8 @@ end
345345

346346
Base.eltype(::BipartiteGraph{T}) where {T} = T
347347

348+
Base.transpose(bg::BipartiteGraph) = BipartiteGraph(bg.S2, bg.S1)
349+
348350
function BipartiteGraph(A::AbstractMatrix; symmetric_pattern::Bool=false)
349351
return BipartiteGraph(SparseMatrixCSC(A); symmetric_pattern)
350352
end

src/matrices.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Return various versions of the same matrix:
1010
1111
Used for internal testing.
1212
"""
13-
function matrix_versions(A)
13+
function matrix_versions(A::AbstractMatrix)
1414
A_dense = Matrix(A)
1515
A_sparse = sparse(A)
1616
versions = [

src/result.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,18 +146,23 @@ $TYPEDFIELDS
146146
- [`AbstractColoringResult`](@ref)
147147
"""
148148
struct ColumnColoringResult{
149-
M<:AbstractMatrix,T<:Integer,G<:BipartiteGraph{T},GT<:AbstractGroups{T}
149+
M<:AbstractMatrix,
150+
T<:Integer,
151+
G<:BipartiteGraph{T},
152+
CT<:AbstractVector{T},
153+
GT<:AbstractGroups{T},
154+
VT<:AbstractVector{T},
150155
} <: AbstractColoringResult{:nonsymmetric,:column,:direct}
151156
"matrix that was colored"
152157
A::M
153158
"bipartite graph that was used for coloring"
154159
bg::G
155160
"one integer color for each column or row (depending on `partition`)"
156-
color::Vector{T}
161+
color::CT
157162
"color groups for columns or rows (depending on `partition`)"
158163
group::GT
159-
"flattened indices mapping the compressed matrix `B` to the uncompressed matrix `A` when `A isa SparseMatrixCSC`. They satisfy `nonzeros(A)[k] = vec(B)[compressed_indices[k]]`"
160-
compressed_indices::Vector{T}
164+
"flattened indices mapping the compressed matrix `B` to the uncompressed matrix `A`. When `A isa SparseMatrixCSC`, they satisfy `nonzeros(A)[k] = vec(B)[compressed_indices[k]]`."
165+
compressed_indices::VT
161166
end
162167

163168
function ColumnColoringResult(
@@ -195,13 +200,18 @@ $TYPEDFIELDS
195200
- [`AbstractColoringResult`](@ref)
196201
"""
197202
struct RowColoringResult{
198-
M<:AbstractMatrix,T<:Integer,G<:BipartiteGraph{T},GT<:AbstractGroups{T}
203+
M<:AbstractMatrix,
204+
T<:Integer,
205+
G<:BipartiteGraph{T},
206+
CT<:AbstractVector{T},
207+
GT<:AbstractGroups{T},
208+
VT<:AbstractVector{T},
199209
} <: AbstractColoringResult{:nonsymmetric,:row,:direct}
200210
A::M
201211
bg::G
202-
color::Vector{T}
212+
color::CT
203213
group::GT
204-
compressed_indices::Vector{T}
214+
compressed_indices::VT
205215
end
206216

207217
function RowColoringResult(
@@ -239,13 +249,18 @@ $TYPEDFIELDS
239249
- [`AbstractColoringResult`](@ref)
240250
"""
241251
struct StarSetColoringResult{
242-
M<:AbstractMatrix,T<:Integer,G<:AdjacencyGraph{T},GT<:AbstractGroups{T}
252+
M<:AbstractMatrix,
253+
T<:Integer,
254+
G<:AdjacencyGraph{T},
255+
CT<:AbstractVector{T},
256+
GT<:AbstractGroups{T},
257+
VT<:AbstractVector{T},
243258
} <: AbstractColoringResult{:symmetric,:column,:direct}
244259
A::M
245260
ag::G
246-
color::Vector{T}
261+
color::CT
247262
group::GT
248-
compressed_indices::Vector{T}
263+
compressed_indices::VT
249264
end
250265

251266
function StarSetColoringResult(

test/cuda.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using CUDA.CUSPARSE: CuSparseMatrixCSC, CuSparseMatrixCSR
2+
using LinearAlgebra
3+
using SparseArrays
4+
using SparseMatrixColorings
5+
using StableRNGs
6+
using Test
7+
8+
rng = StableRNG(63)
9+
10+
asymmetric_params = vcat(
11+
[(10, 20, p) for p in (0.0:0.2:0.5)],
12+
[(20, 10, p) for p in (0.0:0.2:0.5)],
13+
[(100, 200, p) for p in (0.01:0.02:0.05)],
14+
[(200, 100, p) for p in (0.01:0.02:0.05)],
15+
)
16+
17+
symmetric_params = vcat(
18+
[(10, p) for p in (0.0:0.2:0.5)], #
19+
[(100, p) for p in (0.01:0.02:0.05)],
20+
)
21+
22+
@testset "Column coloring & decompression" begin
23+
problem = ColoringProblem(; structure=:nonsymmetric, partition=:column)
24+
algo = GreedyColoringAlgorithm(; decompression=:direct)
25+
@testset for T in (CuSparseMatrixCSC,)
26+
@testset "$((; m, n, p))" for (m, n, p) in asymmetric_params
27+
A0 = T(sprand(rng, m, n, p))
28+
test_coloring_decompression(A0, problem, algo; gpu=true)
29+
end
30+
end
31+
end;
32+
33+
@testset "Row coloring & decompression" begin
34+
problem = ColoringProblem(; structure=:nonsymmetric, partition=:row)
35+
algo = GreedyColoringAlgorithm(; decompression=:direct)
36+
@testset for T in (CuSparseMatrixCSC,)
37+
@testset "$((; m, n, p))" for (m, n, p) in asymmetric_params
38+
A0 = T(sprand(rng, m, n, p))
39+
test_coloring_decompression(A0, problem, algo; gpu=true)
40+
end
41+
end
42+
end;
43+
44+
@testset "Symmetric coloring & direct decompression" begin
45+
problem = ColoringProblem(; structure=:symmetric, partition=:column)
46+
algo = GreedyColoringAlgorithm(; postprocessing=false, decompression=:direct)
47+
@testset for T in (CuSparseMatrixCSC,)
48+
@testset "$((; n, p))" for (n, p) in symmetric_params
49+
A0 = T(sparse(Symmetric(sprand(rng, n, n, p))))
50+
test_coloring_decompression(A0, problem, algo; gpu=true)
51+
end
52+
end
53+
end;

test/utils.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ function test_coloring_decompression(
2222
B0=nothing,
2323
color0=nothing,
2424
test_fast=false,
25+
gpu=false,
2526
) where {structure,partition,decompression}
2627
color_vec = Vector{Int}[]
2728
@testset "$(typeof(A))" for A in matrix_versions(A0)
@@ -60,6 +61,17 @@ function test_coloring_decompression(
6061
!isnothing(B0) && @test B == B0
6162
end
6263

64+
@testset "Full decompression" begin
65+
@test decompress(B, result) A0
66+
@test decompress(B, result) A0 # check result wasn't modified
67+
@test decompress!(respectful_similar(A, eltype(B)), B, result) A0
68+
@test decompress!(respectful_similar(A, eltype(B)), B, result) A0
69+
end
70+
71+
if gpu
72+
continue
73+
end
74+
6375
@testset "Recoverability" begin
6476
# TODO: find tests for recoverability for substitution decompression
6577
if decompression == :direct
@@ -81,13 +93,6 @@ function test_coloring_decompression(
8193
end
8294
end
8395

84-
@testset "Full decompression" begin
85-
@test decompress(B, result) A0
86-
@test decompress(B, result) A0 # check result wasn't modified
87-
@test decompress!(respectful_similar(A, eltype(B)), B, result) A0
88-
@test decompress!(respectful_similar(A, eltype(B)), B, result) A0
89-
end
90-
9196
@testset "Single-color decompression" begin
9297
if decompression == :direct # TODO: implement for :substitution too
9398
A2 = respectful_similar(A, eltype(B))
@@ -194,11 +199,6 @@ function test_bicoloring_decompression(
194199
end
195200
end
196201

197-
if decompression == :direct
198-
@testset "Recoverability" begin
199-
@test structurally_biorthogonal(A0, row_color, column_color)
200-
end
201-
end
202202
@testset "Full decompression" begin
203203
@test decompress(Br, Bc, result) A0
204204
@test decompress(Br, Bc, result) A0 # check result wasn't modified
@@ -209,6 +209,12 @@ function test_bicoloring_decompression(
209209
respectful_similar(A, promote_eltype(Br, Bc)), Br, Bc, result
210210
) A0
211211
end
212+
213+
if decompression == :direct
214+
@testset "Recoverability" begin
215+
@test structurally_biorthogonal(A0, row_color, column_color)
216+
end
217+
end
212218
end
213219
end
214220

0 commit comments

Comments
 (0)