@@ -2,7 +2,7 @@ module SparseMatrixColoringsCUDAExt
22
33import SparseMatrixColorings as SMC
44using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
5- using CUDA: CuVector, CuMatrix
5+ using CUDA: CuVector, CuMatrix, CuRef
66using CUDA. CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR
77
88SMC. matrix_versions (A:: AbstractCuSparseMatrix ) = (A,)
5656function SMC. decompress! (
5757 A:: CuSparseMatrixCSC , B:: CuMatrix , result:: SMC.ColumnColoringResult{<:CuSparseMatrixCSC}
5858)
59- A. nzVal .= getindex .(Ref (B), result. compressed_indices)
59+ A. nzVal .= getindex .(CuRef (B), result. compressed_indices)
6060 return A
6161end
6262
6363function SMC. decompress! (
6464 A:: CuSparseMatrixCSC , B:: CuMatrix , result:: SMC.RowColoringResult{<:CuSparseMatrixCSC}
6565)
66- A. nzVal .= getindex .(Ref (B), result. compressed_indices)
66+ A. nzVal .= getindex .(CuRef (B), result. compressed_indices)
6767 return A
6868end
6969
@@ -72,7 +72,7 @@ function SMC.decompress!(
7272 B:: CuMatrix ,
7373 result:: SMC.StarSetColoringResult{<:CuSparseMatrixCSC} ,
7474)
75- A. nzVal .= getindex .(Ref (B), result. compressed_indices)
75+ A. nzVal .= getindex .(CuRef (B), result. compressed_indices)
7676 return A
7777end
7878
0 commit comments