Skip to content

Commit a430140

Browse files
committed
Optimized implementation for structured
1 parent f3531af commit a430140

7 files changed

Lines changed: 408 additions & 2 deletions

Project.toml

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,33 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1010
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
13+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1314
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1415

16+
[weakdeps]
17+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
18+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
19+
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
20+
21+
[extensions]
22+
SparseMatrixColoringsBandedMatricesExt = "BandedMatrices"
23+
SparseMatrixColoringsBlockBandedMatricesExt = ["BlockArrays", "BlockBandedMatrices"]
24+
1525
[compat]
1626
ADTypes = "1.2.1"
27+
BandedMatrices = "1.7.5"
28+
BlockArrays = "1.1.1"
29+
BlockBandedMatrices = "0.13.1"
1730
Compat = "3.46,4.2"
1831
DataStructures = "0.18"
19-
LinearAlgebra = "<0.0.1, 1"
2032
DocStringExtensions = "0.8,0.9"
33+
LinearAlgebra = "<0.0.1, 1"
2134
Random = "<0.0.1, 1"
35+
Requires = "1.3.0"
2236
SparseArrays = "<0.0.1, 1"
2337
julia = "1.6"
38+
39+
[extras]
40+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
41+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
42+
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"

docs/src/dev.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,9 @@ SparseMatrixColorings.what_fig_61
6565
SparseMatrixColorings.efficient_fig_1
6666
SparseMatrixColorings.efficient_fig_4
6767
```
68+
69+
## Misc
70+
71+
```@docs
72+
SparseMatrixColorings.cycle_range
73+
```
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
module SparseMatrixColoringsBandedMatricesExt
2+
3+
if isdefined(Base, :get_extension)
4+
using BandedMatrices: BandedMatrix, bandrange, bandwidths, colrange, rowrange
5+
using SparseMatrixColorings:
6+
BipartiteGraph,
7+
ColoringProblem,
8+
ColumnColoringResult,
9+
GreedyColoringAlgorithm,
10+
RowColoringResult,
11+
column_colors,
12+
cycle_range,
13+
row_colors
14+
import SparseMatrixColorings as SMC
15+
else
16+
using ..BandedMatrices: BandedMatrix, bandrange, bandwidths, colrange, rowrange
17+
using ..SparseMatrixColorings:
18+
BipartiteGraph,
19+
ColoringProblem,
20+
ColumnColoringResult,
21+
GreedyColoringAlgorithm,
22+
RowColoringResult,
23+
column_colors,
24+
cycle_range,
25+
row_colors
26+
import ..SparseMatrixColorings as SMC
27+
end
28+
29+
#=
30+
This code is partly taken from ArrayInterface.jl and FiniteDiff.jl
31+
https://github.com/JuliaArrays/ArrayInterface.jl
32+
https://github.com/JuliaDiff/FiniteDiff.jl
33+
=#
34+
35+
function SMC.coloring(
36+
A::BandedMatrix,
37+
::ColoringProblem{:nonsymmetric,:column},
38+
algo::GreedyColoringAlgorithm;
39+
kwargs...,
40+
)
41+
width = length(bandrange(A))
42+
color = cycle_range(width, size(A, 2))
43+
bg = BipartiteGraph(A)
44+
return ColumnColoringResult(A, bg, color)
45+
end
46+
47+
function SMC.coloring(
48+
A::BandedMatrix,
49+
::ColoringProblem{:nonsymmetric,:row},
50+
algo::GreedyColoringAlgorithm;
51+
kwargs...,
52+
)
53+
width = length(bandrange(A))
54+
color = cycle_range(width, size(A, 1))
55+
bg = BipartiteGraph(A)
56+
return RowColoringResult(A, bg, color)
57+
end
58+
59+
function SMC.decompress!(A::BandedMatrix, B::AbstractMatrix, result::ColumnColoringResult)
60+
color = column_colors(result)
61+
for j in axes(A, 2)
62+
c = color[j]
63+
for i in colrange(A, j)
64+
A[i, j] = B[i, c]
65+
end
66+
end
67+
return A
68+
end
69+
70+
function SMC.decompress!(A::BandedMatrix, B::AbstractMatrix, result::RowColoringResult)
71+
color = row_colors(result)
72+
for i in axes(A, 1)
73+
c = color[i]
74+
for j in rowrange(A, i)
75+
A[i, j] = B[c, j]
76+
end
77+
end
78+
return A
79+
end
80+
81+
end
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
module SparseMatrixColoringsBlockBandedMatricesExt
2+
3+
if isdefined(Base, :get_extension)
4+
using BlockArrays: blockaxes, blockfirsts, blocklasts, blocksize, blocklengths
5+
using BlockBandedMatrices:
6+
BandedBlockBandedMatrix,
7+
BlockBandedMatrix,
8+
blockbandrange,
9+
blockbandwidths,
10+
blocklengths,
11+
blocksize,
12+
subblockbandwidths
13+
using SparseMatrixColorings:
14+
BipartiteGraph,
15+
ColoringProblem,
16+
ColumnColoringResult,
17+
GreedyColoringAlgorithm,
18+
RowColoringResult,
19+
column_colors,
20+
cycle_range,
21+
row_colors
22+
import SparseMatrixColorings as SMC
23+
else
24+
using ..BlockArrays: blockaxes, blockfirsts, blocklasts, blocksize, blocklengths
25+
using ..BlockBandedMatrices:
26+
BandedBlockBandedMatrix,
27+
BlockBandedMatrix,
28+
blockbandrange,
29+
blockbandwidths,
30+
blocklengths,
31+
blocksize,
32+
subblockbandwidths
33+
using ..SparseMatrixColorings:
34+
BipartiteGraph,
35+
ColoringProblem,
36+
ColumnColoringResult,
37+
GreedyColoringAlgorithm,
38+
RowColoringResult,
39+
column_colors,
40+
cycle_range,
41+
row_colors
42+
import ..SparseMatrixColorings as SMC
43+
end
44+
45+
#=
46+
This code is partly taken from ArrayInterface.jl and FiniteDiff.jl
47+
https://github.com/JuliaArrays/ArrayInterface.jl
48+
https://github.com/JuliaDiff/FiniteDiff.jl
49+
=#
50+
51+
function subblockbandrange(A::BandedBlockBandedMatrix)
52+
u, l = subblockbandwidths(A)
53+
return (-l):u
54+
end
55+
56+
function blockbanded_coloring(
57+
A::Union{BlockBandedMatrix,BandedBlockBandedMatrix}, dim::Integer
58+
)
59+
# consider blocks of columns or rows (let's call them vertices) depending on `dim`
60+
nb_blocks = blocksize(A, dim)
61+
nb_in_block = blocklengths(axes(A, dim))
62+
first_in_block = blockfirsts(axes(A, dim))
63+
last_in_block = blocklasts(axes(A, dim))
64+
color = zeros(Int, size(A, dim))
65+
66+
# give a macroscopic color to each block, so that 2 blocks with the same macro color are orthogonal
67+
# same idea as for BandedMatrices
68+
nb_macrocolors = length(blockbandrange(A))
69+
macrocolor = cycle_range(nb_macrocolors, nb_blocks)
70+
71+
width = if A isa BandedBlockBandedMatrix
72+
# vertices within a block are colored cleverly using bands
73+
length(subblockbandrange(A))
74+
else
75+
# vertices within a block are colored naively with distinct micro colors (~ infinite band width)
76+
typemax(Int)
77+
end
78+
79+
# for each macroscopic color, count how many microscopic colors will be needed
80+
nb_colors_in_macrocolor = zeros(Int, nb_macrocolors)
81+
for mc in 1:nb_macrocolors
82+
largest_nb_in_macrocolor = maximum(nb_in_block[mc:nb_macrocolors:nb_blocks]; init=0)
83+
nb_colors_in_macrocolor[mc] = min(width, largest_nb_in_macrocolor)
84+
end
85+
color_shift_in_macrocolor = vcat(0, cumsum(nb_colors_in_macrocolor)[1:(end - 1)])
86+
87+
# assign a microscopic color to each column as a function of its macroscopic color and its position within the block
88+
for b in 1:nb_blocks
89+
block_color = cycle_range(width, nb_in_block[b])
90+
shift = color_shift_in_macrocolor[macrocolor[b]]
91+
color[first_in_block[b]:last_in_block[b]] .= shift .+ block_color
92+
end
93+
94+
return color
95+
end
96+
97+
function SMC.coloring(
98+
A::Union{BlockBandedMatrix,BandedBlockBandedMatrix},
99+
::ColoringProblem{:nonsymmetric,:column},
100+
algo::GreedyColoringAlgorithm;
101+
kwargs...,
102+
)
103+
color = blockbanded_coloring(A, 2)
104+
bg = BipartiteGraph(A)
105+
return ColumnColoringResult(A, bg, color)
106+
end
107+
108+
function SMC.coloring(
109+
A::Union{BlockBandedMatrix,BandedBlockBandedMatrix},
110+
::ColoringProblem{:nonsymmetric,:row},
111+
algo::GreedyColoringAlgorithm;
112+
kwargs...,
113+
)
114+
color = blockbanded_coloring(A, 1)
115+
bg = BipartiteGraph(A)
116+
return RowColoringResult(A, bg, color)
117+
end
118+
119+
end

src/SparseMatrixColorings.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ using DataStructures: DisjointSets, find_root!, root_union!, num_groups
1616
using DocStringExtensions: README, EXPORTS, SIGNATURES, TYPEDEF, TYPEDFIELDS
1717
using LinearAlgebra:
1818
Adjoint,
19+
Bidiagonal,
1920
Diagonal,
2021
Hermitian,
2122
LowerTriangular,
2223
Symmetric,
2324
Transpose,
25+
Tridiagonal,
2426
UpperTriangular,
2527
adjoint,
2628
checksquare,
@@ -51,6 +53,7 @@ include("matrices.jl")
5153
include("interface.jl")
5254
include("constant.jl")
5355
include("decompression.jl")
56+
include("structured.jl")
5457
include("check.jl")
5558
include("examples.jl")
5659

@@ -63,4 +66,19 @@ export column_groups, row_groups
6366
export sparsity_pattern
6467
export compress, decompress, decompress!, decompress_single_color!
6568

69+
if !isdefined(Base, :get_extension)
70+
using Requires
71+
end
72+
73+
@static if !isdefined(Base, :get_extension)
74+
function __init__()
75+
@require BandedMatrices = "aae01518-5342-5314-be14-df237901396f" include(
76+
"../ext/SparseMatrixColoringsBandedMatricesExt.jl"
77+
)
78+
@require BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" include(
79+
"../ext/SparseMatrixColoringsBlockBandedMatricesExt.jl"
80+
)
81+
end
82+
end
83+
6684
end

0 commit comments

Comments
 (0)