Skip to content

Commit a4e026a

Browse files
committed
Check if a buffer for the decompression of acyclic coloring is needed
1 parent 87282ee commit a4e026a

3 files changed

Lines changed: 22 additions & 3 deletions

File tree

src/decompression.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ function decompress!(
539539
R = eltype(A)
540540
fill!(A, zero(R))
541541

542-
if eltype(buffer) == R
542+
if eltype(buffer) == R || isempty(buffer)
543543
buffer_right_type = buffer
544544
else
545545
buffer_right_type = similar(buffer, R)
@@ -615,7 +615,7 @@ function decompress!(
615615
nzA = nonzeros(A)
616616
uplo == :F && check_same_pattern(A, S)
617617

618-
if eltype(buffer) == R
618+
if eltype(buffer) == R || isempty(buffer)
619619
buffer_right_type = buffer
620620
else
621621
buffer_right_type = similar(buffer, R)

src/result.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,8 @@ function TreeSetColoringResult(
362362

363363
# buffer holds the sum of edge values for subtrees in a tree.
364364
# For each vertex i, buffer[i] is the sum of edge values in the subtree rooted at i.
365-
buffer = Vector{R}(undef, nvertices)
365+
# Note that we don't need a buffer is all trees are stars.
366+
buffer = all(is_star) ? R[] : Vector{R}(undef, nvertices)
366367

367368
return TreeSetColoringResult(
368369
A,

test/allocations.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,21 @@ end;
132132
test_noallocs_structured_decompression(1000; structure, partition, decompression)
133133
end
134134
end;
135+
136+
@testset "Multi-precision acyclic decompression" begin
137+
@testset "$format" for format in ("dense", "sparse")
138+
A = [0 0 1; 0 1 0; 1 0 0]
139+
if format == "sparse"
140+
A = sparse(A)
141+
end
142+
problem = ColoringProblem(; structure=:symmetric, partition=:column)
143+
result = coloring(A, problem, GreedyColoringAlgorithm{:substitution}())
144+
@test isempty(result.buffer)
145+
for T in (Float32, Float64)
146+
C = rand(T) * T.(A)
147+
B = compress(C, result)
148+
bench_multiprecision = @be decompress!(C, B, result)
149+
@test minimum(bench_multiprecision).allocs == 0
150+
end
151+
end
152+
end

0 commit comments

Comments
 (0)