Skip to content

Commit 2592da0

Browse files
committed
Generic compression and result
1 parent 487493c commit 2592da0

File tree

4 files changed

+101
-84
lines changed

4 files changed

+101
-84
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1616
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
1717
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
18+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1819
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
1920
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
2021
cuSPARSE = "b26da814-b3bc-49ef-b0ee-c816305aa060"
@@ -23,19 +24,21 @@ cuSPARSE = "b26da814-b3bc-49ef-b0ee-c816305aa060"
2324
SparseMatrixColoringsCUDAExt = ["CUDA", "cuSPARSE"]
2425
SparseMatrixColoringsCliqueTreesExt = "CliqueTrees"
2526
SparseMatrixColoringsColorsExt = "Colors"
27+
SparseMatrixColoringsGPUArraysExt = "GPUArrays"
2628
SparseMatrixColoringsJuMPExt = ["JuMP", "MathOptInterface"]
2729

2830
[compat]
2931
ADTypes = "1.2.1"
3032
CUDA = "6.0.0"
31-
cuSPARSE = "6.0.0"
3233
CliqueTrees = "1"
3334
Colors = "0.12.11, 0.13"
3435
DocStringExtensions = "0.8,0.9"
36+
GPUArrays = "11.5.0"
3537
JuMP = "1.29.1"
3638
LinearAlgebra = "1"
3739
MathOptInterface = "1.45.0"
3840
PrecompileTools = "1.2.1"
3941
Random = "1"
4042
SparseArrays = "1"
43+
cuSPARSE = "6.0.0"
4144
julia = "1.10"

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -5,88 +5,6 @@ using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
55
using CUDA: CuVector, CuMatrix
66
using cuSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR
77

8-
SMC.matrix_versions(A::AbstractCuSparseMatrix) = (A,)
9-
10-
## Compression (slow, through CPU)
11-
12-
function SMC.compress(A::AbstractCuSparseMatrix, result::SMC.AbstractColoringResult)
13-
return CuMatrix(SMC.compress(SparseMatrixCSC(A), result))
14-
end
15-
16-
## CSC Result
17-
18-
function SMC.ColumnColoringResult(
19-
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
20-
) where {T<:Integer}
21-
group = SMC.group_by_color(T, color)
22-
compressed_indices = SMC.column_csc_indices(bg, color)
23-
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
24-
return SMC.ColumnColoringResult(
25-
A, bg, color, group, compressed_indices, additional_info
26-
)
27-
end
28-
29-
function SMC.RowColoringResult(
30-
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
31-
) where {T<:Integer}
32-
group = SMC.group_by_color(T, color)
33-
compressed_indices = SMC.row_csc_indices(bg, color)
34-
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
35-
return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
36-
end
37-
38-
function SMC.StarSetColoringResult(
39-
A::CuSparseMatrixCSC,
40-
ag::SMC.AdjacencyGraph{T},
41-
color::Vector{<:Integer},
42-
star_set::SMC.StarSet{<:Integer},
43-
) where {T<:Integer}
44-
group = SMC.group_by_color(T, color)
45-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
46-
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
47-
return SMC.StarSetColoringResult(
48-
A, ag, color, group, compressed_indices, additional_info
49-
)
50-
end
51-
52-
## CSR Result
53-
54-
function SMC.ColumnColoringResult(
55-
A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
56-
) where {T<:Integer}
57-
group = SMC.group_by_color(T, color)
58-
compressed_indices = SMC.column_csc_indices(bg, color)
59-
compressed_indices_csr = SMC.column_csr_indices(bg, color)
60-
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices_csr))
61-
return SMC.ColumnColoringResult(
62-
A, bg, color, group, compressed_indices, additional_info
63-
)
64-
end
65-
66-
function SMC.RowColoringResult(
67-
A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
68-
) where {T<:Integer}
69-
group = SMC.group_by_color(T, color)
70-
compressed_indices = SMC.row_csc_indices(bg, color)
71-
compressed_indices_csr = SMC.row_csr_indices(bg, color)
72-
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices_csr))
73-
return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
74-
end
75-
76-
function SMC.StarSetColoringResult(
77-
A::CuSparseMatrixCSR,
78-
ag::SMC.AdjacencyGraph{T},
79-
color::Vector{<:Integer},
80-
star_set::SMC.StarSet{<:Integer},
81-
) where {T<:Integer}
82-
group = SMC.group_by_color(T, color)
83-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
84-
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices))
85-
return SMC.StarSetColoringResult(
86-
A, ag, color, group, compressed_indices, additional_info
87-
)
88-
end
89-
908
## Decompression
919

9210
for R in (:ColumnColoringResult, :RowColoringResult)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
module SparseMatrixColoringsGPUArraysExt
2+
3+
using GPUArrays: dense_array_type
4+
using SparseArrays: SparseMatrixCSC
5+
import SparseMatrixColorings as SMC
6+
7+
SMC.matrix_versions(A::AbstractGPUSparseMatrix) = (A,)
8+
9+
## Compression (slow, through CPU)
10+
11+
function SMC.compress(A::AbstractGPUSparseMatrix, result::SMC.AbstractColoringResult)
12+
A_cpu = SparseMatrixCSC(A)
13+
B_cpu = SMC.compress(A_cpu, result)
14+
B = dense_array_type(A)(B_cpu)
15+
return B
16+
end
17+
18+
## CSC Result
19+
20+
function SMC.ColumnColoringResult(
21+
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
22+
) where {T<:Integer}
23+
group = SMC.group_by_color(T, color)
24+
compressed_indices = SMC.column_csc_indices(bg, color)
25+
additional_info = (; compressed_indices_gpu_csc=dense_array_type(A)(compressed_indices))
26+
return SMC.ColumnColoringResult(
27+
A, bg, color, group, compressed_indices, additional_info
28+
)
29+
end
30+
31+
function SMC.RowColoringResult(
32+
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
33+
) where {T<:Integer}
34+
group = SMC.group_by_color(T, color)
35+
compressed_indices = SMC.row_csc_indices(bg, color)
36+
additional_info = (; compressed_indices_gpu_csc=dense_array_type(A)(compressed_indices))
37+
return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
38+
end
39+
40+
function SMC.StarSetColoringResult(
41+
A::CuSparseMatrixCSC,
42+
ag::SMC.AdjacencyGraph{T},
43+
color::Vector{<:Integer},
44+
star_set::SMC.StarSet{<:Integer},
45+
) where {T<:Integer}
46+
group = SMC.group_by_color(T, color)
47+
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
48+
additional_info = (; compressed_indices_gpu_csc=dense_array_type(A)(compressed_indices))
49+
return SMC.StarSetColoringResult(
50+
A, ag, color, group, compressed_indices, additional_info
51+
)
52+
end
53+
54+
## CSR Result
55+
56+
function SMC.ColumnColoringResult(
57+
A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
58+
) where {T<:Integer}
59+
group = SMC.group_by_color(T, color)
60+
compressed_indices = SMC.column_csc_indices(bg, color)
61+
compressed_indices_csr = SMC.column_csr_indices(bg, color)
62+
additional_info = (;
63+
compressed_indices_gpu_csr=dense_array_type(A)(compressed_indices_csr)
64+
)
65+
return SMC.ColumnColoringResult(
66+
A, bg, color, group, compressed_indices, additional_info
67+
)
68+
end
69+
70+
function SMC.RowColoringResult(
71+
A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
72+
) where {T<:Integer}
73+
group = SMC.group_by_color(T, color)
74+
compressed_indices = SMC.row_csc_indices(bg, color)
75+
compressed_indices_csr = SMC.row_csr_indices(bg, color)
76+
additional_info = (;
77+
compressed_indices_gpu_csr=dense_array_type(A)(compressed_indices_csr)
78+
)
79+
return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
80+
end
81+
82+
function SMC.StarSetColoringResult(
83+
A::CuSparseMatrixCSR,
84+
ag::SMC.AdjacencyGraph{T},
85+
color::Vector{<:Integer},
86+
star_set::SMC.StarSet{<:Integer},
87+
) where {T<:Integer}
88+
group = SMC.group_by_color(T, color)
89+
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
90+
additional_info = (; compressed_indices_gpu_csr=dense_array_type(A)(compressed_indices))
91+
return SMC.StarSetColoringResult(
92+
A, ag, color, group, compressed_indices, additional_info
93+
)
94+
end
95+
96+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ include("utils.jl")
1212
@testset verbose = true "SparseMatrixColorings" begin
1313
if get(ENV, "JULIA_SMC_TEST_GROUP", nothing) == "GPU"
1414
@testset "CUDA" begin
15-
using CUDA
15+
using CUDA, cuSPARSE
1616
include("cuda.jl")
1717
end
1818
else

0 commit comments

Comments
 (0)