Skip to content

Commit 01e695c

Browse files
committed
Single type
1 parent 48202a6 commit 01e695c

6 files changed

Lines changed: 89 additions & 90 deletions

File tree

src/adtypes.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function coloring(
55
kwargs...,
66
)
77
bg = BipartiteGraph(A)
8-
color = convert(Vector{Int}, ADTypes.column_coloring(A, algo))
8+
color = convert(Vector{eltype(bg)}, ADTypes.column_coloring(A, algo))
99
return ColumnColoringResult(A, bg, color)
1010
end
1111

@@ -16,6 +16,6 @@ function coloring(
1616
kwargs...,
1717
)
1818
bg = BipartiteGraph(A)
19-
color = convert(Vector{Int}, ADTypes.row_coloring(A, algo))
19+
color = convert(Vector{eltype(bg)}, ADTypes.row_coloring(A, algo))
2020
return RowColoringResult(A, bg, color)
2121
end

src/coloring.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function star_coloring(
8787
first_neighbor = fill((zero(T), zero(T), zero(T)), nv) # at first no neighbors have been encountered
8888
treated = zeros(T, nv)
8989
star = Vector{T}(undef, ne)
90-
hub = Int[] # one hub for each star, including the trivial ones
90+
hub = T[] # one hub for each star, including the trivial ones
9191
vertices_in_order = vertices(g, order)
9292

9393
for v in vertices_in_order

src/constant.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ julia> column_colors(result)
6565
- [`ADTypes.row_coloring`](@extref ADTypes.row_coloring)
6666
"""
6767
struct ConstantColoringAlgorithm{
68-
partition,M<:AbstractMatrix,T,R<:AbstractColoringResult{:nonsymmetric,partition,:direct}
68+
partition,
69+
M<:AbstractMatrix,
70+
T<:Integer,
71+
R<:AbstractColoringResult{:nonsymmetric,partition,:direct},
6972
} <: ADTypes.AbstractColoringAlgorithm
7073
matrix_template::M
7174
color::Vector{T}
@@ -77,7 +80,7 @@ function ConstantColoringAlgorithm{:column}(
7780
)
7881
bg = BipartiteGraph(matrix_template)
7982
result = ColumnColoringResult(matrix_template, bg, color)
80-
T, M, R = eltype(color), typeof(matrix_template), typeof(result)
83+
T, M, R = eltype(bg), typeof(matrix_template), typeof(result)
8184
return ConstantColoringAlgorithm{:column,M,T,R}(matrix_template, color, result)
8285
end
8386

@@ -86,7 +89,7 @@ function ConstantColoringAlgorithm{:row}(
8689
)
8790
bg = BipartiteGraph(matrix_template)
8891
result = RowColoringResult(matrix_template, bg, color)
89-
T, M, R = eltype(color), typeof(matrix_template), typeof(result)
92+
T, M, R = eltype(bg), typeof(matrix_template), typeof(result)
9093
return ConstantColoringAlgorithm{:row,M,T,R}(matrix_template, color, result)
9194
end
9295

src/decompression.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,12 +677,12 @@ function decompress!(
677677
result::LinearSystemColoringResult,
678678
uplo::Symbol=:F,
679679
)
680-
(; color, strict_upper_nonzero_inds, T_factorization, strict_upper_nonzeros_A) = result
680+
(; color, strict_upper_nonzero_inds, M_factorization, strict_upper_nonzeros_A) = result
681681
S = result.ag.S
682682
uplo == :F && check_same_pattern(A, S)
683683

684684
# TODO: for some reason I cannot use ldiv! with a sparse QR
685-
strict_upper_nonzeros_A = T_factorization \ vec(B)
685+
strict_upper_nonzeros_A = M_factorization \ vec(B)
686686
fill!(A, zero(eltype(A)))
687687
for i in axes(A, 1)
688688
if !iszero(S[i, i])

src/graph.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ end
2323

2424
SparsityPatternCSC(A::SparseMatrixCSC) = SparsityPatternCSC(A.m, A.n, A.colptr, A.rowval)
2525

26+
Base.eltype(::SparsityPatternCSC{Ti}) where {Ti} = Ti
2627
Base.size(S::SparsityPatternCSC) = (S.m, S.n)
2728
Base.size(S::SparsityPatternCSC, d::Integer) = d::Integer <= 2 ? size(S)[d] : 1
2829
Base.axes(S::SparsityPatternCSC, d::Integer) = Base.OneTo(size(S, d))
@@ -227,6 +228,8 @@ struct AdjacencyGraph{T<:Integer,has_diagonal}
227228
edge_to_index::Vector{T}
228229
end
229230

231+
Base.eltype(::AdjacencyGraph{T}) where {T} = T
232+
230233
function AdjacencyGraph(
231234
S::SparsityPatternCSC{T},
232235
edge_to_index::Vector{T}=build_edge_to_index(S);
@@ -343,6 +346,8 @@ struct BipartiteGraph{T<:Integer}
343346
S2::SparsityPatternCSC{T}
344347
end
345348

349+
Base.eltype(::BipartiteGraph{T}) where {T} = T
350+
346351
function BipartiteGraph(A::AbstractMatrix; symmetric_pattern::Bool=false)
347352
return BipartiteGraph(SparseMatrixCSC(A); symmetric_pattern)
348353
end

0 commit comments

Comments
 (0)