Skip to content

Commit cba4d7b

Browse files
committed
Fix the tests with CUDA.jl
1 parent 6c098da commit cba4d7b

1 file changed

Lines changed: 10 additions & 6 deletions

File tree

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,15 @@ function SMC.StarSetColoringResult(
4747
A::CuSparseMatrixCSC,
4848
ag::SMC.AdjacencyGraph{T},
4949
color::Vector{<:Integer},
50-
star_set::SMC.StarSet{<:Integer},
50+
star_set::SMC.StarSet{<:Integer};
51+
decompression_uplo::Symbol=:F,
5152
) where {T<:Integer}
53+
@assert decompression_uplo == :F
5254
group = SMC.group_by_color(T, color)
53-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
55+
compressed_indices = SMC.star_csc_indices(ag, color, star_set, decompression_uplo)
5456
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
5557
return SMC.StarSetColoringResult(
56-
A, ag, color, group, compressed_indices, additional_info
58+
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
5759
)
5860
end
5961

@@ -85,13 +87,15 @@ function SMC.StarSetColoringResult(
8587
A::CuSparseMatrixCSR,
8688
ag::SMC.AdjacencyGraph{T},
8789
color::Vector{<:Integer},
88-
star_set::SMC.StarSet{<:Integer},
90+
star_set::SMC.StarSet{<:Integer};
91+
decompression_uplo::Symbol=:F,
8992
) where {T<:Integer}
93+
@assert decompression_uplo == :F
9094
group = SMC.group_by_color(T, color)
91-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
95+
compressed_indices = SMC.star_csc_indices(ag, color, star_set, decompression_uplo)
9296
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices))
9397
return SMC.StarSetColoringResult(
94-
A, ag, color, group, compressed_indices, additional_info
98+
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
9599
)
96100
end
97101

0 commit comments

Comments
 (0)