@@ -2,7 +2,7 @@ module SparseMatrixColoringsCUDAExt
22
33import SparseMatrixColorings as SMC
44using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
5- using CUDA: CuVector, CuMatrix, CuRef
5+ using CUDA: CuVector, CuMatrix
66using CUDA. CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR
77
88SMC. matrix_versions (A:: AbstractCuSparseMatrix ) = (A,)
@@ -53,17 +53,19 @@ function SMC.StarSetColoringResult(
5353 return SMC. StarSetColoringResult (A, ag, color, result_cpu. group, compressed_indices)
5454end
5555
56+ # TODO : write a kernel
57+
5658function SMC. decompress! (
5759 A:: CuSparseMatrixCSC , B:: CuMatrix , result:: SMC.ColumnColoringResult{<:CuSparseMatrixCSC}
5860)
59- A. nzVal .= getindex .(CuRef (B), result. compressed_indices)
61+ A. nzVal .= getindex .(Ref (B), result. compressed_indices)
6062 return A
6163end
6264
6365function SMC. decompress! (
6466 A:: CuSparseMatrixCSC , B:: CuMatrix , result:: SMC.RowColoringResult{<:CuSparseMatrixCSC}
6567)
66- A. nzVal .= getindex .(CuRef (B), result. compressed_indices)
68+ A. nzVal .= getindex .(Ref (B), result. compressed_indices)
6769 return A
6870end
6971
@@ -72,7 +74,7 @@ function SMC.decompress!(
7274 B:: CuMatrix ,
7375 result:: SMC.StarSetColoringResult{<:CuSparseMatrixCSC} ,
7476)
75- A. nzVal .= getindex .(CuRef (B), result. compressed_indices)
77+ A. nzVal .= getindex .(Ref (B), result. compressed_indices)
7678 return A
7779end
7880
0 commit comments