Skip to content

Option to ignore const checks in EnzymeTestUtils#3053

Open
jlperla wants to merge 1 commit intoEnzymeAD:mainfrom
jlperla:enzymetestutils-ignore-const-checks
Open

Option to ignore const checks in EnzymeTestUtils#3053
jlperla wants to merge 1 commit intoEnzymeAD:mainfrom
jlperla:enzymetestutils-ignore-const-checks

Conversation

@jlperla
Copy link
Copy Markdown

@jlperla jlperla commented Apr 28, 2026

Summary

Adds an opt-in ignore_const_checks::Bool=false keyword to EnzymeTestUtils.test_forward and test_reverse. When true, the post-call mutation assertion is skipped for Const-annotated arguments (and for test_reverse, the FD-vs-AD derivative comparison is skipped on Const args too).

Motivation

Custom rules sometimes legitimately scribble on a Const scratch buffer
during AD without that mutation affecting the gradient — e.g. workspace
structs holding LAPACK workplace arrays we do not want/need to shadow. Today the only workaround is to
monkey-patch the helpers or hand-write the finite-differences checks. This gives users a 1-line opt-in instead.

I am not sure if these are the right unit tests to have for this as they are AI generated.

@wsmoses Please let me know if this is the wrong way to submit suggestions.

…d/test_reverse

Adds an opt-in `ignore_const_checks::Bool=false` keyword to both helpers that,
when `true`, skips the post-call mutation assertion on `Const`-annotated
arguments (and, for `test_reverse`, also the FD-vs-AD derivative comparison
on `Const` args). This is useful when a custom rule legitimately scribbles
on a `Const` scratch buffer during AD without affecting the gradient.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@github-actions
Copy link
Copy Markdown
Contributor

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/lib/EnzymeTestUtils/src/test_forward.jl b/lib/EnzymeTestUtils/src/test_forward.jl
index da1e3f4f..883119b7 100644
--- a/lib/EnzymeTestUtils/src/test_forward.jl
+++ b/lib/EnzymeTestUtils/src/test_forward.jl
@@ -64,8 +64,8 @@ function test_forward(
     rtol::Real=1e-9,
     atol::Real=1e-9,
     testset_name=nothing,
-    runtime_activity::Bool=false,
-    ignore_const_checks::Bool=false,
+        runtime_activity::Bool = false,
+        ignore_const_checks::Bool = false,
 )
     call_with_copy = CallWithCopyKWargs(fkwargs)
     call_with_kwargs = CallWithKWargs(fkwargs)
diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl
index 06399a6f..dc74de47 100644
--- a/lib/EnzymeTestUtils/src/test_reverse.jl
+++ b/lib/EnzymeTestUtils/src/test_reverse.jl
@@ -80,9 +80,9 @@ function test_reverse(
     atol::Real=1e-9,
     testset_name=nothing,
     runtime_activity::Bool=false,
-    output_tangent=nothing,
-    ignore_const_checks::Bool=false,
-)
+        output_tangent = nothing,
+        ignore_const_checks::Bool = false,
+    )
     call_with_captured_kwargs = CallWithKWargs(fkwargs)
     if testset_name === nothing
         testset_name = "test_reverse: $f with return activity $ret_activity on $(_string_activity(args))"
diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl
index effb0006..e7bd6bb4 100644
--- a/lib/EnzymeTestUtils/test/test_forward.jl
+++ b/lib/EnzymeTestUtils/test/test_forward.jl
@@ -25,12 +25,12 @@ end
 f_const_scratch_fwd(x, scratch) = x .^ 2
 
 function EnzymeRules.forward(
-    config,
-    func::Const{typeof(f_const_scratch_fwd)},
-    RT::Type{<:Union{Const,Duplicated,DuplicatedNoNeed}},
-    x::Union{Const,Duplicated},
-    scratch::Const,
-)
+        config,
+        func::Const{typeof(f_const_scratch_fwd)},
+        RT::Type{<:Union{Const, Duplicated, DuplicatedNoNeed}},
+        x::Union{Const, Duplicated},
+        scratch::Const,
+    )
     scratch.val .= x.val  # AD-only scribble on Const scratch buffer
     if RT <: Const
         return func.val(x.val, scratch.val)
@@ -250,7 +250,7 @@ end
                     Duplicated,
                     (x, Duplicated),
                     (scratch, Const);
-                    ignore_const_checks=true,
+                    ignore_const_checks = true,
                 )
             end
         end
diff --git a/lib/EnzymeTestUtils/test/test_reverse.jl b/lib/EnzymeTestUtils/test/test_reverse.jl
index 59e15704..974884fa 100644
--- a/lib/EnzymeTestUtils/test/test_reverse.jl
+++ b/lib/EnzymeTestUtils/test/test_reverse.jl
@@ -22,12 +22,12 @@ end
 f_const_scratch_rev(x, scratch) = sum(abs2, x)
 
 function EnzymeRules.augmented_primal(
-    config::EnzymeRules.RevConfigWidth{1},
-    func::Const{typeof(f_const_scratch_rev)},
-    RT::Type{<:Union{Const,Active}},
-    x::Union{Const,Duplicated},
-    scratch::Const,
-)
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(f_const_scratch_rev)},
+        RT::Type{<:Union{Const, Active}},
+        x::Union{Const, Duplicated},
+        scratch::Const,
+    )
     scratch.val .= x.val  # AD-only scribble on Const scratch buffer
     primal = EnzymeRules.needs_primal(config) ? func.val(x.val, scratch.val) : nothing
     tape = copy(x.val)
@@ -35,13 +35,13 @@ function EnzymeRules.augmented_primal(
 end
 
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfigWidth{1},
-    func::Const{typeof(f_const_scratch_rev)},
-    dret::Union{Active,Type{<:Const}},
-    tape,
-    x::Union{Const,Duplicated},
-    scratch::Const,
-)
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(f_const_scratch_rev)},
+        dret::Union{Active, Type{<:Const}},
+        tape,
+        x::Union{Const, Duplicated},
+        scratch::Const,
+    )
     if !(x isa Const) && dret isa Active
         x.dval .+= 2 .* dret.val .* tape
     end
@@ -278,7 +278,7 @@ end
                 Active,
                 (x, Duplicated),
                 (scratch, Const);
-                ignore_const_checks=true,
+                ignore_const_checks = true,
             )
         end
     end

@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 29, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 69.57%. Comparing base (c384f7f) to head (ac03183).
⚠️ Report is 8 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #3053   +/-   ##
=======================================
  Coverage   69.57%   69.57%           
=======================================
  Files          66       66           
  Lines       21573    21576    +3     
=======================================
+ Hits        15009    15012    +3     
  Misses       6564     6564           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@wsmoses
Copy link
Copy Markdown
Member

wsmoses commented Apr 29, 2026

I see what you're doing here, but I suppose it may be better to have a way to mark some variables as scratch more generally than just all const ones?

@wsmoses wsmoses requested a review from kshyatt April 29, 2026 12:02
@jlperla
Copy link
Copy Markdown
Author

jlperla commented Apr 29, 2026

Yes, that would be far superior. I tried that but didn't go that direction because a new Activity type (or some test-only flag on activity?) seemed too much for just the unit tests. But if you think that something along those lines is the right direction, then I can close this PR and make an issue on the topic. My pirated test_forward and test_reverse solve my problems for now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants