Skip to content

Fix waittask#3000

Open
wsmoses wants to merge 1 commit intomainfrom
waittask
Open

Fix waittask#3000
wsmoses wants to merge 1 commit intomainfrom
waittask

Conversation

@wsmoses
Copy link
Copy Markdown
Member

@wsmoses wsmoses commented Mar 17, 2026

No description provided.

@github-actions
Copy link
Copy Markdown
Contributor

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/internal_rules/core.jl b/src/internal_rules/core.jl
index ccaee4bc..03c5f78e 100644
--- a/src/internal_rules/core.jl
+++ b/src/internal_rules/core.jl
@@ -415,22 +415,22 @@ end
 # -------------------- Base.wait EnzymeRules --------------------
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base.wait)},
-    ::Type{<:Const},
-    t::Duplicated{<:Task},
-)
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base.wait)},
+        ::Type{<:Const},
+        t::Duplicated{<:Task},
+    )
     Base.wait(t.val)
     Base.wait(t.dval)
     return nothing
 end
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base.wait)},
-    ::Type{<:Const},
-    t::BatchDuplicated{T, N},
-) where {T<:Task,N}
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base.wait)},
+        ::Type{<:Const},
+        t::BatchDuplicated{T, N},
+    ) where {T <: Task, N}
     Base.wait(t.val)
     for i in 1:N
         Base.wait(t.dval[i])
@@ -439,32 +439,32 @@ function EnzymeRules.forward(
 end
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base.wait)},
-    ::Type{<:Const},
-    t::Const{<:Task},
-)
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base.wait)},
+        ::Type{<:Const},
+        t::Const{<:Task},
+    )
     Base.wait(t.val)
     return nothing
 end
 
 function EnzymeRules.augmented_primal(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.wait)},
-    ::Type{<:Const},
-    t::Annotation{<:Task},
-)
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.wait)},
+        ::Type{<:Const},
+        t::Annotation{<:Task},
+    )
     Base.wait(t.val)
     return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.wait)},
-    ::Type{<:Const},
-    tape,
-    t::Duplicated{<:Task},
-)
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.wait)},
+        ::Type{<:Const},
+        tape,
+        t::Duplicated{<:Task},
+    )
     if isdefined(t, :dval)
         try
             if !Base.istaskstarted(t.dval) && !Base.istaskdone(t.dval)
@@ -478,12 +478,12 @@ function EnzymeRules.reverse(
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.wait)},
-    ::Type{<:Const},
-    tape,
-    t::BatchDuplicated{T, N},
-) where {T<:Task,N}
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.wait)},
+        ::Type{<:Const},
+        tape,
+        t::BatchDuplicated{T, N},
+    ) where {T <: Task, N}
     for i in 1:N
         if isdefined(t.dval, i)
             try
@@ -499,34 +499,34 @@ function EnzymeRules.reverse(
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.wait)},
-    ::Type{<:Const},
-    tape,
-    t::Const{<:Task},
-)
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.wait)},
+        ::Type{<:Const},
+        tape,
+        t::Const{<:Task},
+    )
     return (nothing,)
 end
 
 # -------------------- Base._wait EnzymeRules --------------------
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base._wait)},
-    ::Type{<:Const},
-    t::Duplicated{<:Task},
-)
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base._wait)},
+        ::Type{<:Const},
+        t::Duplicated{<:Task},
+    )
     Base._wait(t.val)
     Base._wait(t.dval)
     return nothing
 end
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base._wait)},
-    ::Type{<:Const},
-    t::BatchDuplicated{T, N},
-) where {T<:Task,N}
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base._wait)},
+        ::Type{<:Const},
+        t::BatchDuplicated{T, N},
+    ) where {T <: Task, N}
     Base._wait(t.val)
     for i in 1:N
         Base._wait(t.dval[i])
@@ -535,32 +535,32 @@ function EnzymeRules.forward(
 end
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base._wait)},
-    ::Type{<:Const},
-    t::Const{<:Task},
-)
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base._wait)},
+        ::Type{<:Const},
+        t::Const{<:Task},
+    )
     Base._wait(t.val)
     return nothing
 end
 
 function EnzymeRules.augmented_primal(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base._wait)},
-    ::Type{<:Const},
-    t::Annotation{<:Task},
-)
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base._wait)},
+        ::Type{<:Const},
+        t::Annotation{<:Task},
+    )
     Base._wait(t.val)
     return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base._wait)},
-    ::Type{<:Const},
-    tape,
-    t::Duplicated{<:Task},
-)
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base._wait)},
+        ::Type{<:Const},
+        tape,
+        t::Duplicated{<:Task},
+    )
     # the reverse of _wait is to enqueue the shadow
     if isdefined(t, :dval)
         try
@@ -575,12 +575,12 @@ function EnzymeRules.reverse(
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base._wait)},
-    ::Type{<:Const},
-    tape,
-    t::BatchDuplicated{T, N},
-) where {T<:Task,N}
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base._wait)},
+        ::Type{<:Const},
+        tape,
+        t::BatchDuplicated{T, N},
+    ) where {T <: Task, N}
     for i in 1:N
         if isdefined(t.dval, i)
             try
@@ -596,20 +596,20 @@ function EnzymeRules.reverse(
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base._wait)},
-    ::Type{<:Const},
-    tape,
-    t::Const{<:Task},
-)
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base._wait)},
+        ::Type{<:Const},
+        tape,
+        t::Const{<:Task},
+    )
     return (nothing,)
 end
 
 @inline function _fwd_task_return(config::EnzymeRules.FwdConfig, t::Annotation{<:Task})
     needs_primal = EnzymeRules.needs_primal(config)
     needs_shadow = EnzymeRules.needs_shadow(config)
-    
-    if !needs_shadow
+
+    return if !needs_shadow
         if needs_primal
             return t.val
         else
@@ -631,11 +631,11 @@ end
 # -------------------- Base.schedule EnzymeRules --------------------
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base.schedule)},
-    ::Type{RT},
-    t::Duplicated{<:Task},
-) where RT
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base.schedule)},
+        ::Type{RT},
+        t::Duplicated{<:Task},
+    ) where {RT}
     try
         if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
             Base.schedule(t.val)
@@ -655,11 +655,11 @@ function EnzymeRules.forward(
 end
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base.schedule)},
-    ::Type{RT},
-    t::BatchDuplicated{T, N},
-) where {RT, T<:Task, N}
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base.schedule)},
+        ::Type{RT},
+        t::BatchDuplicated{T, N},
+    ) where {RT, T <: Task, N}
     if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
         Base.schedule(t.val)
     end
@@ -678,11 +678,11 @@ function EnzymeRules.forward(
 end
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base.schedule)},
-    ::Type{RT},
-    t::Const{<:Task},
-) where RT
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base.schedule)},
+        ::Type{RT},
+        t::Const{<:Task},
+    ) where {RT}
     if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
         Base.schedule(t.val)
     end
@@ -690,11 +690,11 @@ function EnzymeRules.forward(
 end
 
 function EnzymeRules.augmented_primal(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.schedule)},
-    ::Type{RT},
-    t::Annotation{<:Task},
-) where RT
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.schedule)},
+        ::Type{RT},
+        t::Annotation{<:Task},
+    ) where {RT}
     try
         if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
             Base.schedule(t.val)
@@ -707,12 +707,12 @@ function EnzymeRules.augmented_primal(
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.schedule)},
-    ::Type{RT},
-    tape,
-    t::Duplicated{<:Task},
-) where RT
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.schedule)},
+        ::Type{RT},
+        tape,
+        t::Duplicated{<:Task},
+    ) where {RT}
     # the reverse of schedule is to wait for the shadow
     if isdefined(t, :dval)
         try
@@ -727,12 +727,12 @@ function EnzymeRules.reverse(
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.schedule)},
-    ::Type{RT},
-    tape,
-    t::BatchDuplicated{T, N},
-) where {RT,T<:Task,N}
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.schedule)},
+        ::Type{RT},
+        tape,
+        t::BatchDuplicated{T, N},
+    ) where {RT, T <: Task, N}
     for i in 1:N
         if isdefined(t.dval, i)
             try
@@ -748,23 +748,23 @@ function EnzymeRules.reverse(
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.schedule)},
-    ::Type{RT},
-    tape,
-    t::Const{<:Task},
-) where RT
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.schedule)},
+        ::Type{RT},
+        tape,
+        t::Const{<:Task},
+    ) where {RT}
     return (nothing,)
 end
 
 # -------------------- Base.enq_work EnzymeRules --------------------
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base.enq_work)},
-    ::Type{RT},
-    t::Duplicated{<:Task},
-) where RT
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base.enq_work)},
+        ::Type{RT},
+        t::Duplicated{<:Task},
+    ) where {RT}
     if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
         Base.enq_work(t.val)
     end
@@ -781,11 +781,11 @@ function EnzymeRules.forward(
 end
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base.enq_work)},
-    ::Type{RT},
-    t::BatchDuplicated{T, N},
-) where {RT, T<:Task, N}
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base.enq_work)},
+        ::Type{RT},
+        t::BatchDuplicated{T, N},
+    ) where {RT, T <: Task, N}
     if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
         Base.enq_work(t.val)
     end
@@ -804,11 +804,11 @@ function EnzymeRules.forward(
 end
 
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    ::Const{typeof(Base.enq_work)},
-    ::Type{RT},
-    t::Const{<:Task},
-) where RT
+        config::EnzymeRules.FwdConfig,
+        ::Const{typeof(Base.enq_work)},
+        ::Type{RT},
+        t::Const{<:Task},
+    ) where {RT}
     if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
         Base.enq_work(t.val)
     end
@@ -816,11 +816,11 @@ function EnzymeRules.forward(
 end
 
 function EnzymeRules.augmented_primal(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.enq_work)},
-    ::Type{RT},
-    t::Annotation{<:Task},
-) where RT
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.enq_work)},
+        ::Type{RT},
+        t::Annotation{<:Task},
+    ) where {RT}
     try
         if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
             Base.enq_work(t.val)
@@ -833,12 +833,12 @@ function EnzymeRules.augmented_primal(
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.enq_work)},
-    ::Type{RT},
-    tape,
-    t::Duplicated{<:Task},
-) where RT
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.enq_work)},
+        ::Type{RT},
+        tape,
+        t::Duplicated{<:Task},
+    ) where {RT}
     if isdefined(t, :dval)
         try
             if !Base.istaskdone(t.dval)
@@ -852,12 +852,12 @@ function EnzymeRules.reverse(
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.enq_work)},
-    ::Type{RT},
-    tape,
-    t::BatchDuplicated{T, N},
-) where {RT,T<:Task,N}
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.enq_work)},
+        ::Type{RT},
+        tape,
+        t::BatchDuplicated{T, N},
+    ) where {RT, T <: Task, N}
     for i in 1:N
         if isdefined(t.dval, i)
             try
@@ -873,11 +873,11 @@ function EnzymeRules.reverse(
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    ::Const{typeof(Base.enq_work)},
-    ::Type{RT},
-    tape,
-    t::Const{<:Task},
-) where RT
+        config::EnzymeRules.RevConfig,
+        ::Const{typeof(Base.enq_work)},
+        ::Type{RT},
+        tape,
+        t::Const{<:Task},
+    ) where {RT}
     return (nothing,)
 end
diff --git a/test/threads.jl b/test/threads.jl
index fd67bb84..9547a163 100644
--- a/test/threads.jl
+++ b/test/threads.jl
@@ -173,11 +173,11 @@ end
         nothing
     end
 
-    t1 = Task(()->nothing)
-    t2 = Task(()->nothing)
-    t3 = Task(()->nothing)
-    t4 = Task(()->nothing)
-    
+    t1 = Task(() -> nothing)
+    t2 = Task(() -> nothing)
+    t3 = Task(() -> nothing)
+    t4 = Task(() -> nothing)
+
     # Pre-schedule tasks for wait so we don't deadlock
     Base.schedule(t1)
     Base.schedule(t2)
@@ -187,7 +187,7 @@ end
     Base.wait(t2)
     Base.wait(t3)
     Base.wait(t4)
-    
+
     @test Enzyme.autodiff(Reverse, wait_f, Const(t1)) === ()
     @test Enzyme.autodiff(Reverse, wait_f, Duplicated(t1, t2)) === ()
     @test Enzyme.autodiff(Reverse, wait_f, BatchDuplicated(t1, (t2, t3))) === ()
@@ -205,15 +205,15 @@ end
     @test Enzyme.autodiff(Forward, _wait_f, BatchDuplicated(t1, (t2, t3))) === ()
 
 
-    t5 = Task(()->nothing)
-    t6 = Task(()->nothing)
-    t7 = Task(()->nothing)
-    t8 = Task(()->nothing)
+    t5 = Task(() -> nothing)
+    t6 = Task(() -> nothing)
+    t7 = Task(() -> nothing)
+    t8 = Task(() -> nothing)
 
     @test Enzyme.autodiff(Reverse, schedule_f, Const(t5)) === ()
     @test Enzyme.autodiff(Forward, schedule_f, Const(t6)) === ()
 
-    t9 = Task(()->nothing); t10 = Task(()->nothing)
+    t9 = Task(() -> nothing); t10 = Task(() -> nothing)
     @test Enzyme.autodiff(Reverse, enq_work_f, Const(t9)) === ()
     @test Enzyme.autodiff(Forward, enq_work_f, Const(t10)) === ()
 end

@github-actions
Copy link
Copy Markdown
Contributor

Benchmark Results

main 47038e4... main / 47038e4...
basics/make_zero/namedtuple 0.0559 ± 0.0073 μs 0.0553 ± 0.0079 μs 1.01 ± 0.2
basics/make_zero/struct 0.248 ± 0.011 μs 0.249 ± 0.0099 μs 0.994 ± 0.06
basics/overhead 3.53 ± 0.16 ns 3.46 ± 0.001 ns 1.02 ± 0.047
basics/remake_zero!/namedtuple 0.219 ± 0.012 μs 0.222 ± 0.0097 μs 0.985 ± 0.068
basics/remake_zero!/struct 0.223 ± 0.013 μs 0.224 ± 0.009 μs 0.995 ± 0.069
fold_broadcast/multidim_sum_bcast/1D 10.9 ± 0.61 μs 10.8 ± 0.27 μs 1.01 ± 0.061
fold_broadcast/multidim_sum_bcast/2D 12.2 ± 0.34 μs 12.1 ± 0.34 μs 1.01 ± 0.04
time_to_load 1.03 ± 0.014 s 1.05 ± 0.0057 s 0.987 ± 0.015

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/23175412988/artifacts/5956824796.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant