Option to ignore const checks in EnzymeTestUtils#3053
Option to ignore const checks in EnzymeTestUtils#3053jlperla wants to merge 1 commit intoEnzymeAD:mainfrom
Conversation
…d/test_reverse Adds an opt-in `ignore_const_checks::Bool=false` keyword to both helpers that, when `true`, skips the post-call mutation assertion on `Const`-annotated arguments (and, for `test_reverse`, also the FD-vs-AD derivative comparison on `Const` args). This is useful when a custom rule legitimately scribbles on a `Const` scratch buffer during AD without affecting the gradient. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/lib/EnzymeTestUtils/src/test_forward.jl b/lib/EnzymeTestUtils/src/test_forward.jl
index da1e3f4f..883119b7 100644
--- a/lib/EnzymeTestUtils/src/test_forward.jl
+++ b/lib/EnzymeTestUtils/src/test_forward.jl
@@ -64,8 +64,8 @@ function test_forward(
rtol::Real=1e-9,
atol::Real=1e-9,
testset_name=nothing,
- runtime_activity::Bool=false,
- ignore_const_checks::Bool=false,
+ runtime_activity::Bool = false,
+ ignore_const_checks::Bool = false,
)
call_with_copy = CallWithCopyKWargs(fkwargs)
call_with_kwargs = CallWithKWargs(fkwargs)
diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl
index 06399a6f..dc74de47 100644
--- a/lib/EnzymeTestUtils/src/test_reverse.jl
+++ b/lib/EnzymeTestUtils/src/test_reverse.jl
@@ -80,9 +80,9 @@ function test_reverse(
atol::Real=1e-9,
testset_name=nothing,
runtime_activity::Bool=false,
- output_tangent=nothing,
- ignore_const_checks::Bool=false,
-)
+ output_tangent = nothing,
+ ignore_const_checks::Bool = false,
+ )
call_with_captured_kwargs = CallWithKWargs(fkwargs)
if testset_name === nothing
testset_name = "test_reverse: $f with return activity $ret_activity on $(_string_activity(args))"
diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl
index effb0006..e7bd6bb4 100644
--- a/lib/EnzymeTestUtils/test/test_forward.jl
+++ b/lib/EnzymeTestUtils/test/test_forward.jl
@@ -25,12 +25,12 @@ end
f_const_scratch_fwd(x, scratch) = x .^ 2
function EnzymeRules.forward(
- config,
- func::Const{typeof(f_const_scratch_fwd)},
- RT::Type{<:Union{Const,Duplicated,DuplicatedNoNeed}},
- x::Union{Const,Duplicated},
- scratch::Const,
-)
+ config,
+ func::Const{typeof(f_const_scratch_fwd)},
+ RT::Type{<:Union{Const, Duplicated, DuplicatedNoNeed}},
+ x::Union{Const, Duplicated},
+ scratch::Const,
+ )
scratch.val .= x.val # AD-only scribble on Const scratch buffer
if RT <: Const
return func.val(x.val, scratch.val)
@@ -250,7 +250,7 @@ end
Duplicated,
(x, Duplicated),
(scratch, Const);
- ignore_const_checks=true,
+ ignore_const_checks = true,
)
end
end
diff --git a/lib/EnzymeTestUtils/test/test_reverse.jl b/lib/EnzymeTestUtils/test/test_reverse.jl
index 59e15704..974884fa 100644
--- a/lib/EnzymeTestUtils/test/test_reverse.jl
+++ b/lib/EnzymeTestUtils/test/test_reverse.jl
@@ -22,12 +22,12 @@ end
f_const_scratch_rev(x, scratch) = sum(abs2, x)
function EnzymeRules.augmented_primal(
- config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(f_const_scratch_rev)},
- RT::Type{<:Union{Const,Active}},
- x::Union{Const,Duplicated},
- scratch::Const,
-)
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(f_const_scratch_rev)},
+ RT::Type{<:Union{Const, Active}},
+ x::Union{Const, Duplicated},
+ scratch::Const,
+ )
scratch.val .= x.val # AD-only scribble on Const scratch buffer
primal = EnzymeRules.needs_primal(config) ? func.val(x.val, scratch.val) : nothing
tape = copy(x.val)
@@ -35,13 +35,13 @@ function EnzymeRules.augmented_primal(
end
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(f_const_scratch_rev)},
- dret::Union{Active,Type{<:Const}},
- tape,
- x::Union{Const,Duplicated},
- scratch::Const,
-)
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(f_const_scratch_rev)},
+ dret::Union{Active, Type{<:Const}},
+ tape,
+ x::Union{Const, Duplicated},
+ scratch::Const,
+ )
if !(x isa Const) && dret isa Active
x.dval .+= 2 .* dret.val .* tape
end
@@ -278,7 +278,7 @@ end
Active,
(x, Duplicated),
(scratch, Const);
- ignore_const_checks=true,
+ ignore_const_checks = true,
)
end
end |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #3053 +/- ##
=======================================
Coverage 69.57% 69.57%
=======================================
Files 66 66
Lines 21573 21576 +3
=======================================
+ Hits 15009 15012 +3
Misses 6564 6564 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
I see what you're doing here, but I suppose it may be better to have a way to mark some variables as scratch more generally than just all const ones? |
|
Yes, that would be far superior. I tried that but didn't go that direction because a new Activity type (or some test-only flag on activity?) seemed too much for just the unit tests. But if you think that something along those lines is the right direction, then I can close this PR and make an issue on the topic. My pirated test_forward and test_reverse solve my problems for now. |
Summary
Adds an opt-in
ignore_const_checks::Bool=falsekeyword toEnzymeTestUtils.test_forwardandtest_reverse. Whentrue, the post-call mutation assertion is skipped forConst-annotated arguments (and fortest_reverse, the FD-vs-AD derivative comparison is skipped onConstargs too).Motivation
Custom rules sometimes legitimately scribble on a
Constscratch bufferduring AD without that mutation affecting the gradient — e.g. workspace
structs holding LAPACK workplace arrays we do not want/need to shadow. Today the only workaround is to
monkey-patch the helpers or hand-write the finite-differences checks. This gives users a 1-line opt-in instead.
I am not sure if these are the right unit tests to have for this as they are AI generated.
@wsmoses Please let me know if this is the wrong way to submit suggestions.