Skip to content

Commit 1d2359e

Browse files
committed
No generic result
1 parent 6c4472f commit 1d2359e

File tree

2 files changed

+74
-78
lines changed

2 files changed

+74
-78
lines changed

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,80 @@ using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
55
using CUDA: CuVector, CuMatrix
66
using cuSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR
77

8+
## CSC Result
9+
10+
function SMC.ColumnColoringResult(
11+
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
12+
) where {T<:Integer}
13+
group = SMC.group_by_color(T, color)
14+
compressed_indices = SMC.column_csc_indices(bg, color)
15+
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
16+
return SMC.ColumnColoringResult(
17+
A, bg, color, group, compressed_indices, additional_info
18+
)
19+
end
20+
21+
function SMC.RowColoringResult(
22+
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
23+
) where {T<:Integer}
24+
group = SMC.group_by_color(T, color)
25+
compressed_indices = SMC.row_csc_indices(bg, color)
26+
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
27+
return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
28+
end
29+
30+
function SMC.StarSetColoringResult(
31+
A::CuSparseMatrixCSC,
32+
ag::SMC.AdjacencyGraph{T},
33+
color::Vector{<:Integer},
34+
star_set::SMC.StarSet{<:Integer},
35+
) where {T<:Integer}
36+
group = SMC.group_by_color(T, color)
37+
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
38+
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
39+
return SMC.StarSetColoringResult(
40+
A, ag, color, group, compressed_indices, additional_info
41+
)
42+
end
43+
44+
## CSR Result
45+
46+
function SMC.ColumnColoringResult(
47+
A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
48+
) where {T<:Integer}
49+
group = SMC.group_by_color(T, color)
50+
compressed_indices = SMC.column_csc_indices(bg, color)
51+
compressed_indices_csr = SMC.column_csr_indices(bg, color)
52+
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices_csr))
53+
return SMC.ColumnColoringResult(
54+
A, bg, color, group, compressed_indices, additional_info
55+
)
56+
end
57+
58+
function SMC.RowColoringResult(
59+
A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
60+
) where {T<:Integer}
61+
group = SMC.group_by_color(T, color)
62+
compressed_indices = SMC.row_csc_indices(bg, color)
63+
compressed_indices_csr = SMC.row_csr_indices(bg, color)
64+
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices_csr))
65+
return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
66+
end
67+
68+
function SMC.StarSetColoringResult(
69+
A::CuSparseMatrixCSR,
70+
ag::SMC.AdjacencyGraph{T},
71+
color::Vector{<:Integer},
72+
star_set::SMC.StarSet{<:Integer},
73+
) where {T<:Integer}
74+
group = SMC.group_by_color(T, color)
75+
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
76+
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices))
77+
return SMC.StarSetColoringResult(
78+
A, ag, color, group, compressed_indices, additional_info
79+
)
80+
end
81+
882
## Decompression
983

1084
for R in (:ColumnColoringResult, :RowColoringResult)

ext/SparseMatrixColoringsGPUArraysExt.jl

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -15,82 +15,4 @@ function SMC.compress(A::AbstractGPUSparseMatrix, result::SMC.AbstractColoringRe
1515
return B
1616
end
1717

18-
## CSC Result
19-
20-
function SMC.ColumnColoringResult(
21-
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
22-
) where {T<:Integer}
23-
group = SMC.group_by_color(T, color)
24-
compressed_indices = SMC.column_csc_indices(bg, color)
25-
additional_info = (; compressed_indices_gpu_csc=dense_array_type(A)(compressed_indices))
26-
return SMC.ColumnColoringResult(
27-
A, bg, color, group, compressed_indices, additional_info
28-
)
29-
end
30-
31-
function SMC.RowColoringResult(
32-
A::CuSparseMatrixCSC, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
33-
) where {T<:Integer}
34-
group = SMC.group_by_color(T, color)
35-
compressed_indices = SMC.row_csc_indices(bg, color)
36-
additional_info = (; compressed_indices_gpu_csc=dense_array_type(A)(compressed_indices))
37-
return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
38-
end
39-
40-
function SMC.StarSetColoringResult(
41-
A::CuSparseMatrixCSC,
42-
ag::SMC.AdjacencyGraph{T},
43-
color::Vector{<:Integer},
44-
star_set::SMC.StarSet{<:Integer},
45-
) where {T<:Integer}
46-
group = SMC.group_by_color(T, color)
47-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
48-
additional_info = (; compressed_indices_gpu_csc=dense_array_type(A)(compressed_indices))
49-
return SMC.StarSetColoringResult(
50-
A, ag, color, group, compressed_indices, additional_info
51-
)
52-
end
53-
54-
## CSR Result
55-
56-
function SMC.ColumnColoringResult(
57-
A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
58-
) where {T<:Integer}
59-
group = SMC.group_by_color(T, color)
60-
compressed_indices = SMC.column_csc_indices(bg, color)
61-
compressed_indices_csr = SMC.column_csr_indices(bg, color)
62-
additional_info = (;
63-
compressed_indices_gpu_csr=dense_array_type(A)(compressed_indices_csr)
64-
)
65-
return SMC.ColumnColoringResult(
66-
A, bg, color, group, compressed_indices, additional_info
67-
)
68-
end
69-
70-
function SMC.RowColoringResult(
71-
A::CuSparseMatrixCSR, bg::SMC.BipartiteGraph{T}, color::Vector{<:Integer}
72-
) where {T<:Integer}
73-
group = SMC.group_by_color(T, color)
74-
compressed_indices = SMC.row_csc_indices(bg, color)
75-
compressed_indices_csr = SMC.row_csr_indices(bg, color)
76-
additional_info = (;
77-
compressed_indices_gpu_csr=dense_array_type(A)(compressed_indices_csr)
78-
)
79-
return SMC.RowColoringResult(A, bg, color, group, compressed_indices, additional_info)
80-
end
81-
82-
function SMC.StarSetColoringResult(
83-
A::CuSparseMatrixCSR,
84-
ag::SMC.AdjacencyGraph{T},
85-
color::Vector{<:Integer},
86-
star_set::SMC.StarSet{<:Integer},
87-
) where {T<:Integer}
88-
group = SMC.group_by_color(T, color)
89-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
90-
additional_info = (; compressed_indices_gpu_csr=dense_array_type(A)(compressed_indices))
91-
return SMC.StarSetColoringResult(
92-
A, ag, color, group, compressed_indices, additional_info
93-
)
94-
end
95-
9618
end

0 commit comments

Comments
 (0)