Skip to content

Commit a6110ec

Browse files
committed
Remove CuRef
1 parent 51a7b54 commit a6110ec

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module SparseMatrixColoringsCUDAExt
22

33
import SparseMatrixColorings as SMC
44
using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
5-
using CUDA: CuVector, CuMatrix, CuRef
5+
using CUDA: CuVector, CuMatrix
66
using CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR
77

88
SMC.matrix_versions(A::AbstractCuSparseMatrix) = (A,)
@@ -53,17 +53,19 @@ function SMC.StarSetColoringResult(
5353
return SMC.StarSetColoringResult(A, ag, color, result_cpu.group, compressed_indices)
5454
end
5555

56+
# TODO: write a kernel
57+
5658
function SMC.decompress!(
5759
A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.ColumnColoringResult{<:CuSparseMatrixCSC}
5860
)
59-
A.nzVal .= getindex.(CuRef(B), result.compressed_indices)
61+
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
6062
return A
6163
end
6264

6365
function SMC.decompress!(
6466
A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.RowColoringResult{<:CuSparseMatrixCSC}
6567
)
66-
A.nzVal .= getindex.(CuRef(B), result.compressed_indices)
68+
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
6769
return A
6870
end
6971

@@ -72,7 +74,7 @@ function SMC.decompress!(
7274
B::CuMatrix,
7375
result::SMC.StarSetColoringResult{<:CuSparseMatrixCSC},
7476
)
75-
A.nzVal .= getindex.(CuRef(B), result.compressed_indices)
77+
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
7678
return A
7779
end
7880

0 commit comments

Comments
 (0)