Skip to content

Commit f7daa2f

Browse files
committed
Write CUDA kernel
1 parent a6110ec commit f7daa2f

1 file changed

Lines changed: 25 additions & 21 deletions

File tree

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module SparseMatrixColoringsCUDAExt
22

33
import SparseMatrixColorings as SMC
44
using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
5-
using CUDA: CuVector, CuMatrix
5+
using CUDA:
6+
@cuda, CuVector, CuMatrix, blockIdx, blockDim, gridDim, threadIdx, launch_configuration
67
using CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR
78

89
SMC.matrix_versions(A::AbstractCuSparseMatrix) = (A,)
@@ -53,29 +54,32 @@ function SMC.StarSetColoringResult(
5354
return SMC.StarSetColoringResult(A, ag, color, result_cpu.group, compressed_indices)
5455
end
5556

56-
# TODO: write a kernel
57-
58-
function SMC.decompress!(
59-
A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.ColumnColoringResult{<:CuSparseMatrixCSC}
60-
)
61-
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
62-
return A
63-
end
64-
65-
function SMC.decompress!(
66-
A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.RowColoringResult{<:CuSparseMatrixCSC}
57+
function update_nzval_from_matrix!(
58+
nzVal::AbstractVector, B::AbstractMatrix, compressed_indices::AbstractVector{<:Integer}
6759
)
68-
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
69-
return A
60+
index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
61+
stride = gridDim().x * blockDim().x
62+
for k in index:stride:length(nzVal)
63+
nzVal[k] = B[compressed_indices[k]]
64+
end
65+
return nothing
7066
end
7167

72-
function SMC.decompress!(
73-
A::CuSparseMatrixCSC,
74-
B::CuMatrix,
75-
result::SMC.StarSetColoringResult{<:CuSparseMatrixCSC},
76-
)
77-
A.nzVal .= getindex.(Ref(B), result.compressed_indices)
78-
return A
68+
for R in (:ColumnColoringResult, :RowColoringResult, :StarSetColoringResult)
69+
# loop to avoid method ambiguity
70+
@eval function SMC.decompress!(
71+
A::CuSparseMatrixCSC, B::CuMatrix, result::SMC.$R{<:CuSparseMatrixCSC}
72+
)
73+
A.nnz == 0 && return A
74+
kernel = @cuda launch = false update_nzval_from_matrix!(
75+
A.nzVal, B, result.compressed_indices
76+
)
77+
config = launch_configuration(kernel.fun)
78+
threads = min(A.nnz, config.threads)
79+
blocks = cld(A.nnz, threads)
80+
kernel(A.nzVal, B, result.compressed_indices; threads, blocks)
81+
return A
82+
end
7983
end
8084

8185
end

0 commit comments

Comments
 (0)