Conversation
Contributor
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/internal_rules/core.jl b/src/internal_rules/core.jl
index ccaee4bc..03c5f78e 100644
--- a/src/internal_rules/core.jl
+++ b/src/internal_rules/core.jl
@@ -415,22 +415,22 @@ end
# -------------------- Base.wait EnzymeRules --------------------
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base.wait)},
- ::Type{<:Const},
- t::Duplicated{<:Task},
-)
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base.wait)},
+ ::Type{<:Const},
+ t::Duplicated{<:Task},
+ )
Base.wait(t.val)
Base.wait(t.dval)
return nothing
end
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base.wait)},
- ::Type{<:Const},
- t::BatchDuplicated{T, N},
-) where {T<:Task,N}
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base.wait)},
+ ::Type{<:Const},
+ t::BatchDuplicated{T, N},
+ ) where {T <: Task, N}
Base.wait(t.val)
for i in 1:N
Base.wait(t.dval[i])
@@ -439,32 +439,32 @@ function EnzymeRules.forward(
end
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base.wait)},
- ::Type{<:Const},
- t::Const{<:Task},
-)
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base.wait)},
+ ::Type{<:Const},
+ t::Const{<:Task},
+ )
Base.wait(t.val)
return nothing
end
function EnzymeRules.augmented_primal(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.wait)},
- ::Type{<:Const},
- t::Annotation{<:Task},
-)
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.wait)},
+ ::Type{<:Const},
+ t::Annotation{<:Task},
+ )
Base.wait(t.val)
return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.wait)},
- ::Type{<:Const},
- tape,
- t::Duplicated{<:Task},
-)
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.wait)},
+ ::Type{<:Const},
+ tape,
+ t::Duplicated{<:Task},
+ )
if isdefined(t, :dval)
try
if !Base.istaskstarted(t.dval) && !Base.istaskdone(t.dval)
@@ -478,12 +478,12 @@ function EnzymeRules.reverse(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.wait)},
- ::Type{<:Const},
- tape,
- t::BatchDuplicated{T, N},
-) where {T<:Task,N}
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.wait)},
+ ::Type{<:Const},
+ tape,
+ t::BatchDuplicated{T, N},
+ ) where {T <: Task, N}
for i in 1:N
if isdefined(t.dval, i)
try
@@ -499,34 +499,34 @@ function EnzymeRules.reverse(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.wait)},
- ::Type{<:Const},
- tape,
- t::Const{<:Task},
-)
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.wait)},
+ ::Type{<:Const},
+ tape,
+ t::Const{<:Task},
+ )
return (nothing,)
end
# -------------------- Base._wait EnzymeRules --------------------
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base._wait)},
- ::Type{<:Const},
- t::Duplicated{<:Task},
-)
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base._wait)},
+ ::Type{<:Const},
+ t::Duplicated{<:Task},
+ )
Base._wait(t.val)
Base._wait(t.dval)
return nothing
end
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base._wait)},
- ::Type{<:Const},
- t::BatchDuplicated{T, N},
-) where {T<:Task,N}
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base._wait)},
+ ::Type{<:Const},
+ t::BatchDuplicated{T, N},
+ ) where {T <: Task, N}
Base._wait(t.val)
for i in 1:N
Base._wait(t.dval[i])
@@ -535,32 +535,32 @@ function EnzymeRules.forward(
end
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base._wait)},
- ::Type{<:Const},
- t::Const{<:Task},
-)
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base._wait)},
+ ::Type{<:Const},
+ t::Const{<:Task},
+ )
Base._wait(t.val)
return nothing
end
function EnzymeRules.augmented_primal(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base._wait)},
- ::Type{<:Const},
- t::Annotation{<:Task},
-)
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base._wait)},
+ ::Type{<:Const},
+ t::Annotation{<:Task},
+ )
Base._wait(t.val)
return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base._wait)},
- ::Type{<:Const},
- tape,
- t::Duplicated{<:Task},
-)
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base._wait)},
+ ::Type{<:Const},
+ tape,
+ t::Duplicated{<:Task},
+ )
# the reverse of _wait is to enqueue the shadow
if isdefined(t, :dval)
try
@@ -575,12 +575,12 @@ function EnzymeRules.reverse(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base._wait)},
- ::Type{<:Const},
- tape,
- t::BatchDuplicated{T, N},
-) where {T<:Task,N}
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base._wait)},
+ ::Type{<:Const},
+ tape,
+ t::BatchDuplicated{T, N},
+ ) where {T <: Task, N}
for i in 1:N
if isdefined(t.dval, i)
try
@@ -596,20 +596,20 @@ function EnzymeRules.reverse(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base._wait)},
- ::Type{<:Const},
- tape,
- t::Const{<:Task},
-)
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base._wait)},
+ ::Type{<:Const},
+ tape,
+ t::Const{<:Task},
+ )
return (nothing,)
end
@inline function _fwd_task_return(config::EnzymeRules.FwdConfig, t::Annotation{<:Task})
needs_primal = EnzymeRules.needs_primal(config)
needs_shadow = EnzymeRules.needs_shadow(config)
-
- if !needs_shadow
+
+ return if !needs_shadow
if needs_primal
return t.val
else
@@ -631,11 +631,11 @@ end
# -------------------- Base.schedule EnzymeRules --------------------
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base.schedule)},
- ::Type{RT},
- t::Duplicated{<:Task},
-) where RT
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base.schedule)},
+ ::Type{RT},
+ t::Duplicated{<:Task},
+ ) where {RT}
try
if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
Base.schedule(t.val)
@@ -655,11 +655,11 @@ function EnzymeRules.forward(
end
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base.schedule)},
- ::Type{RT},
- t::BatchDuplicated{T, N},
-) where {RT, T<:Task, N}
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base.schedule)},
+ ::Type{RT},
+ t::BatchDuplicated{T, N},
+ ) where {RT, T <: Task, N}
if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
Base.schedule(t.val)
end
@@ -678,11 +678,11 @@ function EnzymeRules.forward(
end
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base.schedule)},
- ::Type{RT},
- t::Const{<:Task},
-) where RT
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base.schedule)},
+ ::Type{RT},
+ t::Const{<:Task},
+ ) where {RT}
if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
Base.schedule(t.val)
end
@@ -690,11 +690,11 @@ function EnzymeRules.forward(
end
function EnzymeRules.augmented_primal(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.schedule)},
- ::Type{RT},
- t::Annotation{<:Task},
-) where RT
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.schedule)},
+ ::Type{RT},
+ t::Annotation{<:Task},
+ ) where {RT}
try
if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
Base.schedule(t.val)
@@ -707,12 +707,12 @@ function EnzymeRules.augmented_primal(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.schedule)},
- ::Type{RT},
- tape,
- t::Duplicated{<:Task},
-) where RT
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.schedule)},
+ ::Type{RT},
+ tape,
+ t::Duplicated{<:Task},
+ ) where {RT}
# the reverse of schedule is to wait for the shadow
if isdefined(t, :dval)
try
@@ -727,12 +727,12 @@ function EnzymeRules.reverse(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.schedule)},
- ::Type{RT},
- tape,
- t::BatchDuplicated{T, N},
-) where {RT,T<:Task,N}
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.schedule)},
+ ::Type{RT},
+ tape,
+ t::BatchDuplicated{T, N},
+ ) where {RT, T <: Task, N}
for i in 1:N
if isdefined(t.dval, i)
try
@@ -748,23 +748,23 @@ function EnzymeRules.reverse(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.schedule)},
- ::Type{RT},
- tape,
- t::Const{<:Task},
-) where RT
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.schedule)},
+ ::Type{RT},
+ tape,
+ t::Const{<:Task},
+ ) where {RT}
return (nothing,)
end
# -------------------- Base.enq_work EnzymeRules --------------------
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base.enq_work)},
- ::Type{RT},
- t::Duplicated{<:Task},
-) where RT
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base.enq_work)},
+ ::Type{RT},
+ t::Duplicated{<:Task},
+ ) where {RT}
if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
Base.enq_work(t.val)
end
@@ -781,11 +781,11 @@ function EnzymeRules.forward(
end
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base.enq_work)},
- ::Type{RT},
- t::BatchDuplicated{T, N},
-) where {RT, T<:Task, N}
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base.enq_work)},
+ ::Type{RT},
+ t::BatchDuplicated{T, N},
+ ) where {RT, T <: Task, N}
if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
Base.enq_work(t.val)
end
@@ -804,11 +804,11 @@ function EnzymeRules.forward(
end
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- ::Const{typeof(Base.enq_work)},
- ::Type{RT},
- t::Const{<:Task},
-) where RT
+ config::EnzymeRules.FwdConfig,
+ ::Const{typeof(Base.enq_work)},
+ ::Type{RT},
+ t::Const{<:Task},
+ ) where {RT}
if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
Base.enq_work(t.val)
end
@@ -816,11 +816,11 @@ function EnzymeRules.forward(
end
function EnzymeRules.augmented_primal(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.enq_work)},
- ::Type{RT},
- t::Annotation{<:Task},
-) where RT
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.enq_work)},
+ ::Type{RT},
+ t::Annotation{<:Task},
+ ) where {RT}
try
if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val)
Base.enq_work(t.val)
@@ -833,12 +833,12 @@ function EnzymeRules.augmented_primal(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.enq_work)},
- ::Type{RT},
- tape,
- t::Duplicated{<:Task},
-) where RT
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.enq_work)},
+ ::Type{RT},
+ tape,
+ t::Duplicated{<:Task},
+ ) where {RT}
if isdefined(t, :dval)
try
if !Base.istaskdone(t.dval)
@@ -852,12 +852,12 @@ function EnzymeRules.reverse(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.enq_work)},
- ::Type{RT},
- tape,
- t::BatchDuplicated{T, N},
-) where {RT,T<:Task,N}
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.enq_work)},
+ ::Type{RT},
+ tape,
+ t::BatchDuplicated{T, N},
+ ) where {RT, T <: Task, N}
for i in 1:N
if isdefined(t.dval, i)
try
@@ -873,11 +873,11 @@ function EnzymeRules.reverse(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- ::Const{typeof(Base.enq_work)},
- ::Type{RT},
- tape,
- t::Const{<:Task},
-) where RT
+ config::EnzymeRules.RevConfig,
+ ::Const{typeof(Base.enq_work)},
+ ::Type{RT},
+ tape,
+ t::Const{<:Task},
+ ) where {RT}
return (nothing,)
end
diff --git a/test/threads.jl b/test/threads.jl
index fd67bb84..9547a163 100644
--- a/test/threads.jl
+++ b/test/threads.jl
@@ -173,11 +173,11 @@ end
nothing
end
- t1 = Task(()->nothing)
- t2 = Task(()->nothing)
- t3 = Task(()->nothing)
- t4 = Task(()->nothing)
-
+ t1 = Task(() -> nothing)
+ t2 = Task(() -> nothing)
+ t3 = Task(() -> nothing)
+ t4 = Task(() -> nothing)
+
# Pre-schedule tasks for wait so we don't deadlock
Base.schedule(t1)
Base.schedule(t2)
@@ -187,7 +187,7 @@ end
Base.wait(t2)
Base.wait(t3)
Base.wait(t4)
-
+
@test Enzyme.autodiff(Reverse, wait_f, Const(t1)) === ()
@test Enzyme.autodiff(Reverse, wait_f, Duplicated(t1, t2)) === ()
@test Enzyme.autodiff(Reverse, wait_f, BatchDuplicated(t1, (t2, t3))) === ()
@@ -205,15 +205,15 @@ end
@test Enzyme.autodiff(Forward, _wait_f, BatchDuplicated(t1, (t2, t3))) === ()
- t5 = Task(()->nothing)
- t6 = Task(()->nothing)
- t7 = Task(()->nothing)
- t8 = Task(()->nothing)
+ t5 = Task(() -> nothing)
+ t6 = Task(() -> nothing)
+ t7 = Task(() -> nothing)
+ t8 = Task(() -> nothing)
@test Enzyme.autodiff(Reverse, schedule_f, Const(t5)) === ()
@test Enzyme.autodiff(Forward, schedule_f, Const(t6)) === ()
- t9 = Task(()->nothing); t10 = Task(()->nothing)
+ t9 = Task(() -> nothing); t10 = Task(() -> nothing)
@test Enzyme.autodiff(Reverse, enq_work_f, Const(t9)) === ()
@test Enzyme.autodiff(Forward, enq_work_f, Const(t10)) === ()
end |
Contributor
Benchmark Results
Benchmark PlotsA plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/23175412988/artifacts/5956824796. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.