@@ -2,8 +2,7 @@ module SparseMatrixColoringsCUDAExt
22
33import SparseMatrixColorings as SMC
44using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
5- using CUDA:
6- @cuda , CuVector, CuMatrix, blockIdx, blockDim, gridDim, threadIdx, launch_configuration
5+ using CUDA: CuVector, CuMatrix
76using CUDA. CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR
87
98SMC. matrix_versions (A:: AbstractCuSparseMatrix ) = (A,)
9897
9998# # Decompression
10099
101- # COV_EXCL_START
102- function update_nzval_from_matrix! (
103- nzVal:: AbstractVector , B:: AbstractMatrix , compressed_indices:: AbstractVector{<:Integer}
104- )
105- index = (blockIdx (). x - 1 ) * blockDim (). x + threadIdx (). x
106- stride = gridDim (). x * blockDim (). x
107- for k in index: stride: length (nzVal)
108- nzVal[k] = B[compressed_indices[k]]
109- end
110- return nothing
111- end
112- # COV_EXCL_STOP
113-
114100for R in (:ColumnColoringResult , :RowColoringResult , :StarSetColoringResult )
115101 # loop to avoid method ambiguity
116102 @eval function SMC. decompress! (
117103 A:: CuSparseMatrixCSC , B:: CuMatrix , result:: SMC. $ R{<: CuSparseMatrixCSC }
118104 )
119105 compressed_indices = result. additional_info. compressed_indices_gpu_csc
120- A. nnz == 0 && return A
121- kernel = @cuda launch = false update_nzval_from_matrix! (
122- A. nzVal, B, compressed_indices
123- )
124- config = launch_configuration (kernel. fun)
125- threads = min (A. nnz, config. threads)
126- blocks = cld (A. nnz, threads)
127- kernel (A. nzVal, B, compressed_indices; threads, blocks)
106+ map! (Base. Fix1 (getindex, B), A. nzVal, compressed_indices)
128107 return A
129108 end
130109
131110 @eval function SMC. decompress! (
132111 A:: CuSparseMatrixCSR , B:: CuMatrix , result:: SMC. $ R{<: CuSparseMatrixCSR }
133112 )
134113 compressed_indices = result. additional_info. compressed_indices_gpu_csr
135- A. nnz == 0 && return A
136- kernel = @cuda launch = false update_nzval_from_matrix! (
137- A. nzVal, B, compressed_indices
138- )
139- config = launch_configuration (kernel. fun)
140- threads = min (A. nnz, config. threads)
141- blocks = cld (A. nnz, threads)
142- kernel (A. nzVal, B, compressed_indices; threads, blocks)
114+ map! (Base. Fix1 (getindex, B), A. nzVal, compressed_indices)
143115 return A
144116 end
145117end
0 commit comments