Skip to content

Commit 52d9c6a

Browse files
committed
CuRef
1 parent 649a3c8 commit 52d9c6a

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module SparseMatrixColoringsCUDAExt
22

33
import SparseMatrixColorings as SMC
44
using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
5-
using CUDA: CuVector, CuMatrix
5+
using CUDA: CuVector, CuMatrix, CuRef
66
using CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR
77

88
SMC.matrix_versions(A::AbstractCuSparseMatrix) = (A,)
@@ -56,14 +56,14 @@ end
5656
function SMC.decompress!(
5757
A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.ColumnColoringResult{<:CuSparseMatrixCSC}
5858
)
59-
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
59+
A.nzVal .= getindex.(CuRef(B), result.compressed_indices)
6060
return A
6161
end
6262

6363
function SMC.decompress!(
6464
A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.RowColoringResult{<:CuSparseMatrixCSC}
6565
)
66-
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
66+
A.nzVal .= getindex.(CuRef(B), result.compressed_indices)
6767
return A
6868
end
6969

@@ -72,7 +72,7 @@ function SMC.decompress!(
7272
B::CuMatrix,
7373
result::SMC.StarSetColoringResult{<:CuSparseMatrixCSC},
7474
)
75-
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
75+
A.nzVal .= getindex.(CuRef(B), result.compressed_indices)
7676
return A
7777
end
7878

0 commit comments

Comments
 (0)