@@ -2,7 +2,8 @@ module SparseMatrixColoringsCUDAExt
22
33import SparseMatrixColorings as SMC
44using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
5- using CUDA: CuVector, CuMatrix
5+ using CUDA:
6+ @cuda , CuVector, CuMatrix, blockIdx, blockDim, gridDim, threadIdx, launch_configuration
67using CUDA. CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR
78
# For GPU sparse matrices, the only version returned is the matrix itself,
# wrapped in a 1-tuple.
function SMC.matrix_versions(A::AbstractCuSparseMatrix)
    return (A,)
end
@@ -53,29 +54,32 @@ function SMC.StarSetColoringResult(
5354 return SMC. StarSetColoringResult (A, ag, color, result_cpu. group, compressed_indices)
5455end
5556
56- # TODO : write a kernel
57-
58- function SMC. decompress! (
59- A:: CuSparseMatrixCSC , B:: CuMatrix , result:: SMC.ColumnColoringResult{<:CuSparseMatrixCSC}
60- )
61- A. nzVal .= getindex .(Ref (B), result. compressed_indices)
62- return A
63- end
64-
65- function SMC. decompress! (
66- A:: CuSparseMatrixCSC , B:: CuMatrix , result:: SMC.RowColoringResult{<:CuSparseMatrixCSC}
# GPU kernel: for every position `k` of `nzVal`, copy the entry of `B` located at
# linear index `compressed_indices[k]` into `nzVal[k]`.
#
# Each thread starts at its global 1-based index (derived from the block and thread
# coordinates along `x`) and then advances by the total number of launched threads,
# so the whole of `nzVal` is covered whatever the launch configuration is.
# Returns `nothing`, as the body is launched through `@cuda`.
function update_nzval_from_matrix!(
    nzVal::AbstractVector, B::AbstractMatrix, compressed_indices::AbstractVector{<:Integer}
)
    first_k = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    step_k = gridDim().x * blockDim().x
    k = first_k
    while k <= length(nzVal)
        nzVal[k] = B[compressed_indices[k]]
        k += step_k
    end
    return nothing
end
7167
72- function SMC. decompress! (
73- A:: CuSparseMatrixCSC ,
74- B:: CuMatrix ,
75- result:: SMC.StarSetColoringResult{<:CuSparseMatrixCSC} ,
76- )
77- A. nzVal .= getindex .(Ref (B), result. compressed_indices)
78- return A
# Generate one `decompress!` method per result type; the methods are defined in a
# loop (rather than through a single union-typed signature) to avoid method ambiguity.
for R in (:ColumnColoringResult, :RowColoringResult, :StarSetColoringResult)
    @eval function SMC.decompress!(
        A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.$R{<:CuSparseMatrixCSC}
    )
        stored = A.nnz
        # Nothing to copy (and no valid launch configuration) for an empty matrix.
        if stored == 0
            return A
        end
        indices = result.compressed_indices
        # Compile without launching so the launch configuration can be queried first.
        kernel = @cuda launch = false update_nzval_from_matrix!(A.nzVal, B, indices)
        suggested = launch_configuration(kernel.fun)
        # Never use more threads per block than there are stored entries, then pick
        # enough blocks to give each stored entry its own thread.
        threads = min(stored, suggested.threads)
        blocks = cld(stored, threads)
        kernel(A.nzVal, B, indices; threads=threads, blocks=blocks)
        return A
    end
end
8084
8185end
0 commit comments