@@ -22,24 +22,26 @@ function SMC.compress(
2222 return CuMatrix (SMC. compress (SparseMatrixCSC (A), result))
2323end
2424
25- # # CSC
25+ # # CSC Result
2626
2727function SMC. ColumnColoringResult (
2828 A:: CuSparseMatrixCSC , bg:: SMC.BipartiteGraph{T} , color:: Vector{<:Integer}
2929) where {T<: Integer }
30- A_cpu = SparseMatrixCSC (A)
31- result_cpu = SMC. ColumnColoringResult (A_cpu, bg, color)
32- compressed_indices = CuVector (result_cpu. compressed_indices)
33- return SMC. ColumnColoringResult (A, bg, color, result_cpu. group, compressed_indices)
30+ group = SMC. group_by_color (T, color)
31+ compressed_indices = SMC. column_csc_indices (bg, color)
32+ additional_info = (; compressed_indices_gpu_csc= CuVector (compressed_indices))
33+ return SMC. ColumnColoringResult (
34+ A, bg, color, group, compressed_indices, additional_info
35+ )
3436end
3537
3638function SMC. RowColoringResult (
3739 A:: CuSparseMatrixCSC , bg:: SMC.BipartiteGraph{T} , color:: Vector{<:Integer}
3840) where {T<: Integer }
39- A_cpu = SparseMatrixCSC (A )
40- result_cpu = SMC. RowColoringResult (A_cpu, bg, color)
41- compressed_indices = CuVector (result_cpu . compressed_indices)
42- return SMC. RowColoringResult (A, bg, color, result_cpu . group, compressed_indices)
41+ group = SMC . group_by_color (T, color )
42+ compressed_indices = SMC. row_csc_indices ( bg, color)
43+ additional_info = (; compressed_indices_gpu_csc = CuVector (compressed_indices) )
44+ return SMC. RowColoringResult (A, bg, color, group, compressed_indices, additional_info )
4345end
4446
4547function SMC. StarSetColoringResult (
@@ -48,12 +50,54 @@ function SMC.StarSetColoringResult(
4850 color:: Vector{<:Integer} ,
4951 star_set:: SMC.StarSet{<:Integer} ,
5052) where {T<: Integer }
51- A_cpu = SparseMatrixCSC (A)
52- result_cpu = SMC. StarSetColoringResult (A_cpu, ag, color, star_set)
53- compressed_indices = CuVector (result_cpu. compressed_indices)
54- return SMC. StarSetColoringResult (A, ag, color, result_cpu. group, compressed_indices)
53+ group = SMC. group_by_color (T, color)
54+ compressed_indices = SMC. star_csc_indices (ag, color, star_set)
55+ additional_info = (; compressed_indices_gpu_csc= CuVector (compressed_indices))
56+ return SMC. StarSetColoringResult (
57+ A, ag, color, group, compressed_indices, additional_info
58+ )
59+ end
60+
61+ # # CSR Result
62+
63+ function SMC. ColumnColoringResult (
64+ A:: CuSparseMatrixCSR , bg:: SMC.BipartiteGraph{T} , color:: Vector{<:Integer}
65+ ) where {T<: Integer }
66+ group = SMC. group_by_color (T, color)
67+ compressed_indices = SMC. column_csc_indices (bg, color)
68+ compressed_indices_csr = SMC. column_csr_indices (bg, color)
69+ additional_info = (; compressed_indices_gpu_csr= CuVector (compressed_indices_csr))
70+ return SMC. ColumnColoringResult (
71+ A, bg, color, group, compressed_indices, additional_info
72+ )
73+ end
74+
75+ function SMC. RowColoringResult (
76+ A:: CuSparseMatrixCSR , bg:: SMC.BipartiteGraph{T} , color:: Vector{<:Integer}
77+ ) where {T<: Integer }
78+ group = SMC. group_by_color (T, color)
79+ compressed_indices = SMC. row_csc_indices (bg, color)
80+ compressed_indices_csr = SMC. row_csr_indices (bg, color)
81+ additional_info = (; compressed_indices_gpu_csr= CuVector (compressed_indices_csr))
82+ return SMC. RowColoringResult (A, bg, color, group, compressed_indices, additional_info)
83+ end
84+
85+ function SMC. StarSetColoringResult (
86+ A:: CuSparseMatrixCSR ,
87+ ag:: SMC.AdjacencyGraph{T} ,
88+ color:: Vector{<:Integer} ,
89+ star_set:: SMC.StarSet{<:Integer} ,
90+ ) where {T<: Integer }
91+ group = SMC. group_by_color (T, color)
92+ compressed_indices = SMC. star_csc_indices (ag, color, star_set)
93+ additional_info = (; compressed_indices_gpu_csr= CuVector (compressed_indices))
94+ return SMC. StarSetColoringResult (
95+ A, ag, color, group, compressed_indices, additional_info
96+ )
5597end
5698
99+ # # Decompression
100+
57101function update_nzval_from_matrix! (
58102 nzVal:: AbstractVector , B:: AbstractMatrix , compressed_indices:: AbstractVector{<:Integer}
59103)
@@ -70,14 +114,30 @@ for R in (:ColumnColoringResult, :RowColoringResult, :StarSetColoringResult)
70114 @eval function SMC. decompress! (
71115 A:: CuSparseMatrixCSC , B:: CuMatrix , result:: SMC. $ R{<: CuSparseMatrixCSC }
72116 )
117+ compressed_indices = result. additional_info. compressed_indices_gpu_csc
118+ A. nnz == 0 && return A
119+ kernel = @cuda launch = false update_nzval_from_matrix! (
120+ A. nzVal, B, compressed_indices
121+ )
122+ config = launch_configuration (kernel. fun)
123+ threads = min (A. nnz, config. threads)
124+ blocks = cld (A. nnz, threads)
125+ kernel (A. nzVal, B, compressed_indices; threads, blocks)
126+ return A
127+ end
128+
129+ @eval function SMC. decompress! (
130+ A:: CuSparseMatrixCSR , B:: CuMatrix , result:: SMC. $ R{<: CuSparseMatrixCSR }
131+ )
132+ compressed_indices = result. additional_info. compressed_indices_gpu_csr
73133 A. nnz == 0 && return A
74134 kernel = @cuda launch = false update_nzval_from_matrix! (
75- A. nzVal, B, result . compressed_indices
135+ A. nzVal, B, compressed_indices
76136 )
77137 config = launch_configuration (kernel. fun)
78138 threads = min (A. nnz, config. threads)
79139 blocks = cld (A. nnz, threads)
80- kernel (A. nzVal, B, result . compressed_indices; threads, blocks)
140+ kernel (A. nzVal, B, compressed_indices; threads, blocks)
81141 return A
82142 end
83143end
0 commit comments