Skip to content

Add rules for jl_eqtable_pop and jl_eqtable_nextind (#3042)#3043

Open
vchuravy wants to merge 7 commits intomainfrom
vc/eqtable_rules
Open

Add rules for jl_eqtable_pop and jl_eqtable_nextind (#3042)#3043
vchuravy wants to merge 7 commits intomainfrom
vc/eqtable_rules

Conversation

@vchuravy
Copy link
Copy Markdown
Member

Implement augfwd, fwd, and rev rules for jl_eqtable_pop mimicking eqtableput
Add argument activity annotations for jl_eqtable_pop
Mark jl_eqtable_nextind as pure read-only and inactive
Add both functions to validation whitelist
Add GPUArraysCore to test dependencies and include an @allowscalar test in iddict.jl

Fixes #3042

Co-authored-by: Gemini 3.1 Pro (High) gemini@google.com

Implement augfwd, fwd, and rev rules for jl_eqtable_pop mimicking eqtableput
Add argument activity annotations for jl_eqtable_pop
Mark jl_eqtable_nextind as pure read-only and inactive
Add both functions to validation whitelist
Add GPUArraysCore to test dependencies and include an @allowscalar test in iddict.jl

Fixes #3042

Co-authored-by: Gemini 3.1 Pro (High) <gemini@google.com>
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 23, 2026

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl
index a4704e18..69394911 100644
--- a/src/compiler/validation.jl
+++ b/src/compiler/validation.jl
@@ -89,8 +89,8 @@ function __init__()
         "ijl_eqtable_get",
         "jl_eqtable_put",
         "ijl_eqtable_put",
-        "jl_eqtable_pop",
-        "ijl_eqtable_pop",
+            "jl_eqtable_pop",
+            "ijl_eqtable_pop",
         "memcmp",
         "memchr",
         "jl_get_nth_field_checked",
diff --git a/src/llvm/attributes.jl b/src/llvm/attributes.jl
index 03c0a0ac..4b1bd0c4 100644
--- a/src/llvm/attributes.jl
+++ b/src/llvm/attributes.jl
@@ -210,10 +210,10 @@ const nofreefns = Set{String}((
     "cuCtxGetId",
     "cuDeviceGetName",
     "ijl_eqtable_get",
-    "jl_eqtable_pop",
-    "ijl_eqtable_pop",
-    "jl_eqtable_nextind",
-    "ijl_eqtable_nextind",
+        "jl_eqtable_pop",
+        "ijl_eqtable_pop",
+        "jl_eqtable_nextind",
+        "ijl_eqtable_nextind",
     "cuCtxGetApiVersion",
     "cuCtxSetCurrent",
 ))
@@ -344,8 +344,8 @@ const inactivefns = Set{String}((
     "jl_array_to_string",
     "ijl_array_to_string",
     "pcre2_jit_compile_8",
-    "jl_eqtable_nextind",
-    "ijl_eqtable_nextind",
+        "jl_eqtable_nextind",
+        "ijl_eqtable_nextind",
     # "jl_"
 ))
 
@@ -762,8 +762,8 @@ function annotate!(mod::LLVM.Module)
         "jl_reshape_array",
         "ijl_eqtable_get",
         "jl_eqtable_get",
-        "jl_eqtable_pop",
-        "ijl_eqtable_pop",
+            "jl_eqtable_pop",
+            "ijl_eqtable_pop",
         "jl_gc_run_pending_finalizers",
         "ijl_try_substrtod",
         "jl_try_substrtod",
@@ -1128,8 +1128,8 @@ function annotate!(mod::LLVM.Module)
                             "memory",
                             MemoryEffect(
                                 (MRI_ModRef << getLocationPos(ArgMem)) |
-                                (MRI_NoModRef << getLocationPos(InaccessibleMem)) |
-                                (MRI_NoModRef << getLocationPos(Other)),
+                                    (MRI_NoModRef << getLocationPos(InaccessibleMem)) |
+                                    (MRI_NoModRef << getLocationPos(Other)),
                             ).data,
                         ),
                     )
@@ -1153,8 +1153,8 @@ function annotate!(mod::LLVM.Module)
                             "memory",
                             MemoryEffect(
                                 (MRI_Ref << getLocationPos(ArgMem)) |
-                                (MRI_NoModRef << getLocationPos(InaccessibleMem)) |
-                                (MRI_NoModRef << getLocationPos(Other)),
+                                    (MRI_NoModRef << getLocationPos(InaccessibleMem)) |
+                                    (MRI_NoModRef << getLocationPos(Other)),
                             ).data,
                         ),
                     )
diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl
index 11980cce..bcd451fa 100644
--- a/src/rules/llvmrules.jl
+++ b/src/rules/llvmrules.jl
@@ -1166,25 +1166,25 @@ end
             B,
             orig,
             "Enzyme: Not yet implemented constant table in jl_eqtable_get " *
-            string(origh) *
-            " " *
-            string(orig) *
-            " result: " *
-            string(absint(orig)) *
-            " " *
-            string(abs_typeof(orig, true)) *
-            " dict: " *
-            string(absint(origh)) *
-            " " *
-            string(abs_typeof(origh, true)) *
-            " key " *
-            string(absint(origkey)) *
-            " " *
-            string(abs_typeof(origkey, true)) *
-            " dflt " *
-            string(absint(origdflt)) *
-            " " *
-            string(abs_typeof(origdflt, true)),
+                string(origh) *
+                " " *
+                string(orig) *
+                " result: " *
+                string(absint(orig)) *
+                " " *
+                string(abs_typeof(orig, true)) *
+                " dict: " *
+                string(absint(origh)) *
+                " " *
+                string(abs_typeof(origh, true)) *
+                " key " *
+                string(absint(origkey)) *
+                " " *
+                string(abs_typeof(origkey, true)) *
+                " dflt " *
+                string(absint(origdflt)) *
+                " " *
+                string(abs_typeof(origdflt, true)),
         )
         return false
     end
@@ -1196,9 +1196,9 @@ end
             Base.unsafe_convert(
                 Cstring,
                 "Mixed activity for default of jl_eqtable_get " *
-                string(orig) *
-                " " *
-                string(origdflt),
+                    string(orig) *
+                    " " *
+                    string(origdflt),
             ),
             orig.ref,
             API.ET_MixedActivityError,
@@ -1215,7 +1215,7 @@ end
             else
                 ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(nop)))
                 shadowm = LLVM.UndefValue(ST)
-                for j = 1:width
+                for j in 1:width
                     shadowm = insert_value!(B, shadowm, nop, j - 1)
                 end
                 shadowm
@@ -1387,9 +1387,9 @@ end
             Base.unsafe_convert(
                 Cstring,
                 "Mixed activity for val of jl_eqtable_put " *
-                string(orig) *
-                " " *
-                string(origval),
+                    string(orig) *
+                    " " *
+                    string(origval),
             ),
             orig.ref,
             API.ET_MixedActivityError,
@@ -1406,7 +1406,7 @@ end
             else
                 ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(nop)))
                 shadowm = LLVM.UndefValue(ST)
-                for j = 1:width
+                for j in 1:width
                     shadowm = insert_value!(B, shadowm, nop, j - 1)
                 end
                 shadowm
@@ -1419,10 +1419,10 @@ end
     newvals = API.CValueType[API.VT_Shadow, API.VT_Primal, API.VT_Shadow, API.VT_None]
 
     newops = LLVM.Value[
-      shadowh,
-      new_from_original(gutils, origkey),
-      shadowval,
-      LLVM.null(value_type(originserted)),
+        shadowh,
+        new_from_original(gutils, origkey),
+        shadowval,
+        LLVM.null(value_type(originserted)),
     ]
 
     shadowres = batch_call_same_with_inverted_arg_if_active!(B, gutils, orig, newops, newvals, false) #=lookup=#
@@ -1520,9 +1520,9 @@ end
             Base.unsafe_convert(
                 Cstring,
                 "Mixed activity for default of jl_eqtable_pop " *
-                string(orig) *
-                " " *
-                string(origdflt),
+                    string(orig) *
+                    " " *
+                    string(origdflt),
             ),
             orig.ref,
             API.ET_MixedActivityError,
@@ -1539,7 +1539,7 @@ end
             else
                 ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(nop)))
                 shadowm = LLVM.UndefValue(ST)
-                for j = 1:width
+                for j in 1:width
                     shadowm = insert_value!(B, shadowm, nop, j - 1)
                 end
                 shadowm
@@ -1552,10 +1552,10 @@ end
     newvals = API.CValueType[API.VT_Shadow, API.VT_Primal, API.VT_Shadow, API.VT_None]
 
     newops = LLVM.Value[
-      shadowh,
-      new_from_original(gutils, origkey),
-      shadowdflt,
-      LLVM.null(value_type(origfound)),
+        shadowh,
+        new_from_original(gutils, origkey),
+        shadowdflt,
+        LLVM.null(value_type(origfound)),
     ]
 
     shadowres = batch_call_same_with_inverted_arg_if_active!(B, gutils, orig, newops, newvals, false) #=lookup=#
@@ -1584,9 +1584,9 @@ end
             Base.unsafe_convert(
                 Cstring,
                 "Mixed activity for default of jl_eqtable_pop " *
-                string(orig) *
-                " " *
-                string(origdflt),
+                    string(orig) *
+                    " " *
+                    string(origdflt),
             ),
             orig.ref,
             API.ET_MixedActivityError,
@@ -1603,7 +1603,7 @@ end
             else
                 ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(nop)))
                 shadowm = LLVM.UndefValue(ST)
-                for j = 1:width
+                for j in 1:width
                     shadowm = insert_value!(B, shadowm, nop, j - 1)
                 end
                 shadowm
@@ -1616,10 +1616,10 @@ end
     newvals = API.CValueType[API.VT_Shadow, API.VT_Primal, API.VT_Shadow, API.VT_None]
 
     newops = LLVM.Value[
-      shadowh,
-      new_from_original(gutils, origkey),
-      shadowdflt,
-      LLVM.null(value_type(origfound)),
+        shadowh,
+        new_from_original(gutils, origkey),
+        shadowdflt,
+        LLVM.null(value_type(origfound)),
     ]
 
     shadowres = batch_call_same_with_inverted_arg_if_active!(B, gutils, orig, newops, newvals, false) #=lookup=#

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 23, 2026

Benchmark Results

main 5aa1529... main / 5aa1529...
basics/make_zero/namedtuple 0.0539 ± 0.002 μs 0.0534 ± 0.002 μs 1.01 ± 0.053
basics/make_zero/struct 0.275 ± 0.0059 μs 0.288 ± 0.0073 μs 0.955 ± 0.032
basics/overhead 4.03 ± 0.01 ns 4.95 ± 0.92 ns 0.814 ± 0.15
basics/remake_zero!/namedtuple 0.223 ± 0.0088 μs 0.225 ± 0.01 μs 0.994 ± 0.059
basics/remake_zero!/struct 0.228 ± 0.01 μs 0.229 ± 0.0081 μs 0.996 ± 0.057
fold_broadcast/multidim_sum_bcast/1D 10.3 ± 0.39 μs 10.3 ± 1.8 μs 0.993 ± 0.17
fold_broadcast/multidim_sum_bcast/2D 12.2 ± 0.29 μs 10.3 ± 0.2 μs 1.18 ± 0.036
time_to_load 1.08 ± 0.022 s 1.1 ± 0.015 s 0.987 ± 0.024

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/25425001320/artifacts/6826042797.

vchuravy and others added 5 commits April 23, 2026 19:32
Properly compute shadow values for jl_eqtable_get, jl_eqtable_put, and jl_eqtable_pop in forward mode, conditionally on shadowR.

Co-authored-by: Gemini 3.1 Pro (High) <gemini@google.com>
Since IdDict stores boxed values, which are represented as pointers in Enzyme, we do not need custom reverse-mode tracking. Gradients naturally accumulate into the shadow boxes.

Added tests for storing and reading active arrays and floats in IdDict with inactive keys.

Co-authored-by: Gemini 3.1 Pro (High) <gemini@google.com>
Since IdDict maps seamlessly to its shadow representation, Duplicated(IdDict(), IdDict()) works naturally with our implementation.

Co-authored-by: Gemini 3.1 Pro (High) <gemini@google.com>
Added tests for Enzyme.make_zero, Enzyme.make_zero!, and Enzyme.remake_zero! on IdDict. Also tested Enzyme.jacobian in both Forward and Reverse modes with an IdDict as input by passing explicit shadows.

Co-authored-by: Gemini 3.1 Pro (High) <gemini@google.com>
Added a test that populates an IdDict with an active Float64, an active Vector{Float64}, and an inactive Int. This verifies that Enzyme correctly accumulates gradients when interacting with dynamically-typed boxes and respects mixed activity in runtime structures.

Co-authored-by: Gemini 3.1 Pro (High) <gemini@google.com>
@vchuravy
Copy link
Copy Markdown
Member Author

The x86 error is JuliaLLVM/LLVM.jl#548

@vchuravy vchuravy requested a review from wsmoses April 26, 2026 18:21
Comment thread src/rules/llvmrules.jl
if shadowR != C_NULL
unsafe_store!(shadowR, shadowres.ref)
end
return false
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is correct, if the popped value guaranteed to be the same jlvalue within the list?

Exercises cases left uncovered by the existing tests, where pop!'s
return value is used in computation (existing tests only call pop!
for side effects via @allowscalar/delete!):

- pop! return value used in forward and reverse mode
- pop! with non-trivial stored expression (shadow must match put!, not default)
- pop! with key present but explicit default (default shadow must not leak)
- split reverse mode: gradient flows into the object captured at forward time,
  not into whatever is currently in the shadow dict (mutation-between-passes test)
- pop! on missing key returns active default with correct gradient
- overwrite same key twice: pop! must carry second put's shadow, not first
- double pop same key: second (missing) pop must not inherit first's shadow
- two distinct keys: per-key shadows must not mix
- pop then put back: shadow table reflects re-inserted transformed value

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

@allowscalar causes error in autodiff (even on CPU)

2 participants