Skip to content

Commit da60d38

Browse files
committed
Propagate arbitrary integer types as indices
1 parent b465a79 commit da60d38

6 files changed

Lines changed: 200 additions & 151 deletions

File tree

src/check.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,9 @@ This is done by checking, for each ordered vertex, that its back- or forward-deg
273273
This function is not coded with efficiency in mind, it is designed for small-scale tests.
274274
"""
275275
function valid_dynamic_order(
276-
g::AdjacencyGraph, π::AbstractVector{Int}, ::DynamicDegreeBasedOrder{degtype,direction}
276+
g::AdjacencyGraph,
277+
π::AbstractVector{<:Integer},
278+
::DynamicDegreeBasedOrder{degtype,direction},
277279
) where {degtype,direction}
278280
length(π) != nb_vertices(g) && return false
279281
length(unique(π)) != nb_vertices(g) && return false
@@ -300,7 +302,7 @@ end
300302
function valid_dynamic_order(
301303
g::BipartiteGraph,
302304
::Val{side},
303-
π::AbstractVector{Int},
305+
π::AbstractVector{<:Integer},
304306
::DynamicDegreeBasedOrder{degtype,direction},
305307
) where {side,degtype,direction}
306308
length(π) != nb_vertices(g, Val(side)) && return false

src/coloring.jl

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@ The vertices are colored in a greedy fashion, following the `order` supplied.
1717
> [_What Color Is Your Jacobian? Graph Coloring for Computing Derivatives_](https://epubs.siam.org/doi/10.1137/S0036144504444711), Gebremedhin et al. (2005), Algorithm 3.2
1818
"""
1919
function partial_distance2_coloring(
20-
bg::BipartiteGraph, ::Val{side}, order::AbstractOrder
21-
) where {side}
22-
color = Vector{Int}(undef, nb_vertices(bg, Val(side)))
23-
forbidden_colors = Vector{Int}(undef, nb_vertices(bg, Val(side)))
20+
bg::BipartiteGraph{T}, ::Val{side}, order::AbstractOrder
21+
) where {T,side}
22+
color = Vector{T}(undef, nb_vertices(bg, Val(side)))
23+
forbidden_colors = Vector{T}(undef, nb_vertices(bg, Val(side)))
2424
vertices_in_order = vertices(bg, Val(side), order)
2525
partial_distance2_coloring!(color, forbidden_colors, bg, Val(side), vertices_in_order)
2626
return color
2727
end
2828

2929
function partial_distance2_coloring!(
30-
color::Vector{Int},
31-
forbidden_colors::Vector{Int},
30+
color::AbstractVector{<:Integer},
31+
forbidden_colors::AbstractVector{<:Integer},
3232
bg::BipartiteGraph,
3333
::Val{side},
3434
vertices_in_order::AbstractVector{<:Integer},
@@ -76,15 +76,17 @@ If `postprocessing=true`, some colors might be replaced with `0` (the "neutral"
7676
7777
> [_New Acyclic and Star Coloring Algorithms with Application to Computing Hessians_](https://epubs.siam.org/doi/abs/10.1137/050639879), Gebremedhin et al. (2007), Algorithm 4.1
7878
"""
79-
function star_coloring(g::AdjacencyGraph, order::AbstractOrder, postprocessing::Bool)
79+
function star_coloring(
80+
g::AdjacencyGraph{T}, order::AbstractOrder, postprocessing::Bool
81+
) where {T<:Integer}
8082
# Initialize data structures
8183
nv = nb_vertices(g)
8284
ne = nb_edges(g)
83-
color = zeros(Int, nv)
84-
forbidden_colors = zeros(Int, nv)
85-
first_neighbor = fill((0, 0, 0), nv) # at first no neighbors have been encountered
86-
treated = zeros(Int, nv)
87-
star = Vector{Int}(undef, ne)
85+
color = zeros(T, nv)
86+
forbidden_colors = zeros(T, nv)
87+
first_neighbor = fill((zero(T), zero(T), zero(T)), nv) # at first no neighbors have been encountered
88+
treated = zeros(T, nv)
89+
star = Vector{T}(undef, ne)
8890
hub = Int[] # one hub for each star, including the trivial ones
8991
vertices_in_order = vertices(g, order)
9092

@@ -196,11 +198,11 @@ Encode a set of 2-colored stars resulting from the [`star_coloring`](@ref) algor
196198
197199
$TYPEDFIELDS
198200
"""
199-
struct StarSet
201+
struct StarSet{T}
200202
"a mapping from edges (pair of vertices) to their star index"
201-
star::Vector{Int}
203+
star::Vector{T}
202204
"a mapping from star indices to their hub (undefined hubs for single-edge stars are the negative value of one of the vertices, picked arbitrarily)"
203-
hub::Vector{Int}
205+
hub::Vector{T}
204206
end
205207

206208
"""
@@ -226,15 +228,17 @@ If `postprocessing=true`, some colors might be replaced with `0` (the "neutral"
226228
227229
> [_New Acyclic and Star Coloring Algorithms with Application to Computing Hessians_](https://epubs.siam.org/doi/abs/10.1137/050639879), Gebremedhin et al. (2007), Algorithm 3.1
228230
"""
229-
function acyclic_coloring(g::AdjacencyGraph, order::AbstractOrder, postprocessing::Bool)
231+
function acyclic_coloring(
232+
g::AdjacencyGraph{T}, order::AbstractOrder, postprocessing::Bool
233+
) where {T<:Integer}
230234
# Initialize data structures
231235
nv = nb_vertices(g)
232236
ne = nb_edges(g)
233-
color = zeros(Int, nv)
234-
forbidden_colors = zeros(Int, nv)
235-
first_neighbor = fill((0, 0, 0), nv) # at first no neighbors have been encountered
236-
first_visit_to_tree = fill((0, 0), ne)
237-
forest = Forest{Int}(ne)
237+
color = zeros(T, nv)
238+
forbidden_colors = zeros(T, nv)
239+
first_neighbor = fill((zero(T), zero(T), zero(T)), nv) # at first no neighbors have been encountered
240+
first_visit_to_tree = fill((zero(T), zero(T)), ne)
241+
forest = Forest{T}(ne)
238242
vertices_in_order = vertices(g, order)
239243

240244
for v in vertices_in_order
@@ -367,23 +371,23 @@ Encode a set of 2-colored trees resulting from the [`acyclic_coloring`](@ref) al
367371
368372
$TYPEDFIELDS
369373
"""
370-
struct TreeSet
371-
reverse_bfs_orders::Vector{Vector{Tuple{Int,Int}}}
374+
struct TreeSet{T}
375+
reverse_bfs_orders::Vector{Vector{Tuple{T,T}}}
372376
is_star::Vector{Bool}
373377
end
374378

375-
function TreeSet(g::AdjacencyGraph, forest::Forest{Int})
379+
function TreeSet(g::AdjacencyGraph{T}, forest::Forest) where {T}
376380
S = pattern(g)
377381
edge_to_index = edge_indices(g)
378382
nv = nb_vertices(g)
379383
nt = forest.num_trees
380384

381385
# dictionary that maps a tree's root to the index of the tree
382-
roots = Dict{Int,Int}()
386+
roots = Dict{T,T}()
383387
sizehint!(roots, nt)
384388

385389
# vector of dictionaries where each dictionary stores the neighbors of each vertex in a tree
386-
trees = [Dict{Int,Vector{Int}}() for i in 1:nt]
390+
trees = [Dict{T,Vector{T}}() for i in 1:nt]
387391

388392
# current number of roots found
389393
nr = 0
@@ -423,10 +427,10 @@ function TreeSet(g::AdjacencyGraph, forest::Forest{Int})
423427
end
424428

425429
# degrees is a vector of integers that stores the degree of each vertex in a tree
426-
degrees = Vector{Int}(undef, nv)
430+
degrees = Vector{T}(undef, nv)
427431

428432
# reverse breadth first (BFS) traversal order for each tree in the forest
429-
reverse_bfs_orders = [Tuple{Int,Int}[] for i in 1:nt]
433+
reverse_bfs_orders = [Tuple{T,T}[] for i in 1:nt]
430434

431435
# nvmax is the number of vertices of the biggest tree in the forest
432436
nvmax = 0
@@ -436,7 +440,7 @@ function TreeSet(g::AdjacencyGraph, forest::Forest{Int})
436440
end
437441

438442
# Create a queue with a fixed size nvmax
439-
queue = Vector{Int}(undef, nvmax)
443+
queue = Vector{T}(undef, nvmax)
440444

441445
# Specify if each tree in the forest is a star,
442446
# meaning that one vertex is directly connected to all other vertices in the tree
@@ -519,7 +523,7 @@ function postprocess!(
519523
color::AbstractVector{<:Integer},
520524
star_or_tree_set::Union{StarSet,TreeSet},
521525
g::AdjacencyGraph,
522-
offsets::Vector{Int},
526+
offsets::AbstractVector{<:Integer},
523527
)
524528
S = pattern(g)
525529
edge_to_index = edge_indices(g)

src/constant.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Indeed, for symmetric coloring problems, we need more than just the vector of co
1414
1515
- `partition::Symbol`: either `:row` or `:column`.
1616
- `matrix_template::AbstractMatrix`: matrix for which the vector of colors was precomputed (the algorithm will only accept matrices of the exact same size).
17-
- `color::Vector{Int}`: vector of integer colors, one for each row or column (depending on `partition`).
17+
- `color::Vector{<:Integer}`: vector of integer colors, one for each row or column (depending on `partition`).
1818
1919
!!! warning
2020
The second constructor (based on keyword arguments) is type-unstable.
@@ -65,33 +65,33 @@ julia> column_colors(result)
6565
- [`ADTypes.row_coloring`](@extref ADTypes.row_coloring)
6666
"""
6767
struct ConstantColoringAlgorithm{
68-
partition,M<:AbstractMatrix,R<:AbstractColoringResult{:nonsymmetric,partition,:direct}
68+
partition,M<:AbstractMatrix,T,R<:AbstractColoringResult{:nonsymmetric,partition,:direct}
6969
} <: ADTypes.AbstractColoringAlgorithm
7070
matrix_template::M
71-
color::Vector{Int}
71+
color::Vector{T}
7272
result::R
7373
end
7474

7575
function ConstantColoringAlgorithm{:column}(
76-
matrix_template::AbstractMatrix, color::Vector{Int}
76+
matrix_template::AbstractMatrix, color::Vector{<:Integer}
7777
)
7878
bg = BipartiteGraph(matrix_template)
7979
result = ColumnColoringResult(matrix_template, bg, color)
80-
M, R = typeof(matrix_template), typeof(result)
81-
return ConstantColoringAlgorithm{:column,M,R}(matrix_template, color, result)
80+
T, M, R = eltype(color), typeof(matrix_template), typeof(result)
81+
return ConstantColoringAlgorithm{:column,M,T,R}(matrix_template, color, result)
8282
end
8383

8484
function ConstantColoringAlgorithm{:row}(
85-
matrix_template::AbstractMatrix, color::Vector{Int}
85+
matrix_template::AbstractMatrix, color::Vector{<:Integer}
8686
)
8787
bg = BipartiteGraph(matrix_template)
8888
result = RowColoringResult(matrix_template, bg, color)
89-
M, R = typeof(matrix_template), typeof(result)
90-
return ConstantColoringAlgorithm{:row,M,R}(matrix_template, color, result)
89+
T, M, R = eltype(color), typeof(matrix_template), typeof(result)
90+
return ConstantColoringAlgorithm{:row,M,T,R}(matrix_template, color, result)
9191
end
9292

9393
function ConstantColoringAlgorithm(
94-
matrix_template::AbstractMatrix, color::Vector{Int}; partition=:column
94+
matrix_template::AbstractMatrix, color::Vector{<:Integer}; partition::Symbol=:column
9595
)
9696
return ConstantColoringAlgorithm{partition}(matrix_template, color)
9797
end

src/graph.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ end
2424
SparsityPatternCSC(A::SparseMatrixCSC) = SparsityPatternCSC(A.m, A.n, A.colptr, A.rowval)
2525

2626
Base.size(S::SparsityPatternCSC) = (S.m, S.n)
27-
Base.size(S::SparsityPatternCSC, d) = d::Integer <= 2 ? size(S)[d] : 1
27+
Base.size(S::SparsityPatternCSC, d::Integer) = d::Integer <= 2 ? size(S)[d] : 1
2828
Base.axes(S::SparsityPatternCSC, d::Integer) = Base.OneTo(size(S, d))
2929

3030
SparseArrays.nnz(S::SparsityPatternCSC) = length(S.rowval)
@@ -222,7 +222,7 @@ The adjacency graph of a symmetric matrix `A ∈ ℝ^{n × n}` is `G(A) = (V, E)
222222
223223
> [_What Color Is Your Jacobian? SparsityPatternCSC Coloring for Computing Derivatives_](https://epubs.siam.org/doi/10.1137/S0036144504444711), Gebremedhin et al. (2005)
224224
"""
225-
struct AdjacencyGraph{T,has_diagonal}
225+
struct AdjacencyGraph{T<:Integer,has_diagonal}
226226
S::SparsityPatternCSC{T}
227227
edge_to_index::Vector{T}
228228
end
@@ -298,7 +298,7 @@ function has_neighbor(g::AdjacencyGraph, v::Integer, u::Integer)
298298
return false
299299
end
300300

301-
function degree_in_subset(g::AdjacencyGraph, v::Integer, subset::AbstractVector{Int})
301+
function degree_in_subset(g::AdjacencyGraph, v::Integer, subset::AbstractVector{<:Integer})
302302
d = 0
303303
for u in subset
304304
if has_neighbor(g, v, u)
@@ -338,7 +338,7 @@ When `symmetric_pattern` is `true`, this construction is more efficient.
338338
339339
> [_What Color Is Your Jacobian? SparsityPatternCSC Coloring for Computing Derivatives_](https://epubs.siam.org/doi/10.1137/S0036144504444711), Gebremedhin et al. (2005)
340340
"""
341-
struct BipartiteGraph{T}
341+
struct BipartiteGraph{T<:Integer}
342342
S1::SparsityPatternCSC{T}
343343
S2::SparsityPatternCSC{T}
344344
end
@@ -425,7 +425,7 @@ function has_neighbor_dist2(
425425
end
426426

427427
function degree_dist2_in_subset(
428-
bg::BipartiteGraph, ::Val{side}, v::Integer, subset::AbstractVector{Int}
428+
bg::BipartiteGraph, ::Val{side}, v::Integer, subset::AbstractVector{<:Integer}
429429
) where {side}
430430
d = 0
431431
for u in subset

src/order.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ function vertices(g::AdjacencyGraph, ::LargestFirst)
8484
return sort(vertices(g); by=criterion, rev=true)
8585
end
8686

87-
function vertices(bg::BipartiteGraph, ::Val{side}, ::LargestFirst) where {side}
87+
function vertices(bg::BipartiteGraph{T}, ::Val{side}, ::LargestFirst) where {T,side}
8888
other_side = 3 - side
8989
n = nb_vertices(bg, Val(side))
9090
visited = falses(n) # necessary for distance-2 neighborhoods
91-
degrees_dist2 = zeros(Int, n)
91+
degrees_dist2 = zeros(T, n)
9292
for v in vertices(bg, Val(side))
9393
fill!(visited, false)
9494
for u in neighbors(bg, Val(side), v)
@@ -129,27 +129,27 @@ Instance of [`AbstractOrder`](@ref) which sorts vertices using a dynamically com
129129
"""
130130
struct DynamicDegreeBasedOrder{degtype,direction} <: AbstractOrder end
131131

132-
struct DegreeBuckets
133-
degrees::Vector{Int}
134-
bucket_storage::Vector{Int}
135-
bucket_low::Vector{Int}
136-
bucket_high::Vector{Int}
137-
positions::Vector{Int}
132+
struct DegreeBuckets{T}
133+
degrees::Vector{T}
134+
bucket_storage::Vector{T}
135+
bucket_low::Vector{T}
136+
bucket_high::Vector{T}
137+
positions::Vector{T}
138138
end
139139

140-
function DegreeBuckets(degrees::Vector{Int}, dmax)
140+
function DegreeBuckets(::Type{T}, degrees::Vector{<:Integer}, dmax::Integer) where {T}
141141
# number of vertices per degree class
142-
deg_count = zeros(Int, dmax + 1)
142+
deg_count = zeros(T, dmax + 1)
143143
for d in degrees
144144
deg_count[d + 1] += 1
145145
end
146146
# bucket limits
147147
bucket_high = cumsum(deg_count)
148-
bucket_low = vcat(0, @view(bucket_high[1:(end - 1)]))
148+
bucket_low = vcat(zero(T), @view(bucket_high[1:(end - 1)]))
149149
bucket_low .+= 1
150150
# assign each vertex to the correct position inside its degree class
151-
bucket_storage = similar(degrees, Int)
152-
positions = similar(degrees, Int)
151+
bucket_storage = similar(degrees, T)
152+
positions = similar(degrees, T)
153153
for v in eachindex(positions, degrees)
154154
d = degrees[v]
155155
positions[v] = bucket_high[d + 1] - deg_count[d + 1] + 1
@@ -168,9 +168,9 @@ function degree_increasing(; degtype, direction)
168168
return increasing
169169
end
170170

171-
function mark_ordered!(db::DegreeBuckets, v::Integer)
171+
function mark_ordered!(db::DegreeBuckets{T}, v::Integer) where {T}
172172
db.degrees[v] = -1
173-
db.positions[v] = typemin(Int)
173+
db.positions[v] = typemin(T)
174174
return nothing
175175
end
176176

@@ -248,15 +248,15 @@ function update_bucket!(db::DegreeBuckets, v::Integer; degtype, direction)
248248
end
249249

250250
function vertices(
251-
g::AdjacencyGraph, ::DynamicDegreeBasedOrder{degtype,direction}
252-
) where {degtype,direction}
251+
g::AdjacencyGraph{T}, ::DynamicDegreeBasedOrder{degtype,direction}
252+
) where {T<:Integer,degtype,direction}
253253
if degree_increasing(; degtype, direction)
254-
degrees = zeros(Int, nb_vertices(g))
254+
degrees = zeros(T, nb_vertices(g))
255255
else
256256
degrees = [degree(g, v) for v in vertices(g)]
257257
end
258-
db = DegreeBuckets(degrees, maximum_degree(g))
259-
π = Int[]
258+
db = DegreeBuckets(T, degrees, maximum_degree(g))
259+
π = T[]
260260
sizehint!(π, nb_vertices(g))
261261
for _ in 1:nb_vertices(g)
262262
u = pop_next_candidate!(db; direction)
@@ -271,12 +271,12 @@ function vertices(
271271
end
272272

273273
function vertices(
274-
g::BipartiteGraph, ::Val{side}, ::DynamicDegreeBasedOrder{degtype,direction}
275-
) where {side,degtype,direction}
274+
g::BipartiteGraph{T}, ::Val{side}, ::DynamicDegreeBasedOrder{degtype,direction}
275+
) where {T<:Integer,side,degtype,direction}
276276
other_side = 3 - side
277277
# compute dist-2 degrees in an optimized way
278278
n = nb_vertices(g, Val(side))
279-
degrees_dist2 = zeros(Int, n)
279+
degrees_dist2 = zeros(T, n)
280280
dist2_neighbor = falses(n)
281281
for v in vertices(g, Val(side))
282282
fill!(dist2_neighbor, false)
@@ -288,13 +288,13 @@ function vertices(
288288
degrees_dist2[v] = sum(dist2_neighbor)
289289
end
290290
if degree_increasing(; degtype, direction)
291-
degrees = zeros(Int, n)
291+
degrees = zeros(T, n)
292292
else
293293
degrees = degrees_dist2
294294
end
295295
maxd2 = maximum(degrees_dist2)
296-
db = DegreeBuckets(degrees, maxd2)
297-
π = Int[]
296+
db = DegreeBuckets(T, degrees, maxd2)
297+
π = T[]
298298
sizehint!(π, n)
299299
visited = falses(n)
300300
for _ in 1:nb_vertices(g, Val(side))

0 commit comments

Comments
 (0)