Skip to content

Separated out memory instance rewrite pass#2561

Merged
wsmoses merged 5 commits into main from
mi2
Sep 11, 2025
Merged

Separated out memory instance rewrite pass#2561
wsmoses merged 5 commits into main from
mi2

Conversation

@wsmoses
Copy link
Copy Markdown
Member

@wsmoses wsmoses commented Sep 11, 2025

Separated from #2555

This simplifies the creation of 0-sized memory during `similar`.

The issue with broadcasting stems from the inability to tell that the inner pointer of the newly created array doesn't alias any of the inputs. The "check if zero-sized and pick a load from this global vs. a new alloca" logic complicates this -- though simplifying it is not sufficient in isolation.

before:

L30:                                              ; preds = %L17, %top
  %36 = phi i64 [ %14, %top ], [ %9, %L17 ], !enzyme_inactive !0
  %.not = icmp eq i64 %36, 0, !dbg !948
  br i1 %.not, label %L43, label %L45, !dbg !948

L43:                                              ; preds = %L30
  %37 = load atomic {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 4727432960 to {} addrspace(10)**) unordered, align 256, !dbg !957, !tbaa !18, !alias.scope !764, !noalias !765, !nonnull !0, !enzyme_type !783, !enzymejl_byref_BITS_REF !0, !enzymejl_source_type_Memory\7BFloat64\7D !0
  %.not149 = icmp eq {} addrspace(10)* %37, null, !dbg !957
  br i1 %.not149, label %fail, label %L47, !dbg !957

L45:                                              ; preds = %L30
  %38 = call noalias "enzyme_ReadOnlyOrThrow" "enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,8,-1]:Float@double}" {} addrspace(10)* @jl_alloc_genericmemory({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 4727432928 to {}*) to {} addrspace(10)*), i64 %36) #22, !dbg !958
  br label %L47, !dbg !958

L47:                                              ; preds = %L45, %L43
  %39 = phi {} addrspace(10)* [ %38, %L45 ], [ %37, %L43 ]

after:

L30:                                              ; preds = %L17, %top
  %36 = phi i64 [ %14, %top ], [ %9, %L17 ], !enzyme_inactive !0
  %.not = icmp eq i64 %36, 0, !dbg !948
  br i1 %.not, label %L47, label %L45, !dbg !948

L45:                                              ; preds = %L30
  %37 = call noalias "enzyme_ReadOnlyOrThrow" "enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,8,-1]:Float@double}" {} addrspace(10)* @jl_alloc_genericmemory({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 128612969528928 to {}*) to {} addrspace(10)*), i64 %36) #22, !dbg !957
  br label %L47, !dbg !957

L47:                                              ; preds = %L30, %L45
  %38 = phi {} addrspace(10)* [ %37, %L45 ], [ addrspacecast ({}* inttoptr (i64 128612969529008 to {}*) to {} addrspace(10)*), %L30 ]

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Sep 11, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/absint.jl b/src/absint.jl
index 62817f5..4a9632a 100644
--- a/src/absint.jl
+++ b/src/absint.jl
@@ -268,24 +268,24 @@ function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Boo
                 larg = operands(larg)[1]
                 continue
             end
-	    if opcode(larg) == LLVM.API.LLVMGetElementPtr && pinst isa LLVM.Instruction
-		    b = LLVM.IRBuilder()
-		    position!(b, pinst)
-		    offty = LLVM.IntType(8 * sizeof(Int))
-		    offset2 = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty)
-		    if isa(offset2, LLVM.ConstantInt)
-			val = convert(Int, offset2)
-			if offsetAllowed || val == 0
-			    offset += val
-			    larg = operands(larg)[1]
-			    continue
-			else
-			    break
-			end
-		    else
-			break
-		    end
-		end
+            if opcode(larg) == LLVM.API.LLVMGetElementPtr && pinst isa LLVM.Instruction
+                b = LLVM.IRBuilder()
+                position!(b, pinst)
+                offty = LLVM.IntType(8 * sizeof(Int))
+                offset2 = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty)
+                if isa(offset2, LLVM.ConstantInt)
+                    val = convert(Int, offset2)
+                    if offsetAllowed || val == 0
+                        offset += val
+                        larg = operands(larg)[1]
+                        continue
+                    else
+                        break
+                    end
+                else
+                    break
+                end
+            end
         end
         if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) || isa(larg, LLVM.IntToPtrInst)
             larg = operands(larg)[1]
diff --git a/src/errors.jl b/src/errors.jl
index 08c49dc..e917ed9 100644
--- a/src/errors.jl
+++ b/src/errors.jl
@@ -524,29 +524,29 @@ function julia_error(
                 return seen[cur]
             end
 
-@static if VERSION < v"1.11-"
-else   
-	if isa(cur, LLVM.LoadInst) && isa(value_type(cur), LLVM.PointerType) && LLVM.addrspace(value_type(operands(cur)[1])) == Derived
-                    larg, off = get_base_and_offset(operands(cur)[1]; inst=ncur, inttoptr=true)
-		    if isa(larg, LLVM.ConstantInt) && off == sizeof(Int)
-			ptr = reinterpret(Ptr{Cvoid}, convert(UInt, larg))
-			obj = Base.unsafe_pointer_to_objref(ptr)
+            @static if VERSION < v"1.11-"
+            else
+                if isa(cur, LLVM.LoadInst) && isa(value_type(cur), LLVM.PointerType) && LLVM.addrspace(value_type(operands(cur)[1])) == Derived
+                    larg, off = get_base_and_offset(operands(cur)[1]; inst = ncur, inttoptr = true)
+                    if isa(larg, LLVM.ConstantInt) && off == sizeof(Int)
+                        ptr = reinterpret(Ptr{Cvoid}, convert(UInt, larg))
+                        obj = Base.unsafe_pointer_to_objref(ptr)
                         if obj isa Memory && obj == typeof(obj).instance
                             return make_batched(ncur, prevbb)
                         end
-		    end
+                    end
                 end
-	if isa(cur, LLVM.ConstantExpr) && isa(value_type(cur), LLVM.PointerType) && LLVM.addrspace(value_type(cur)) == Derived
-		larg, off = get_base_and_offset(cur; inst=first(instructions(position(prevbb))), inttoptr=true)
-		if isa(larg, LLVM.ConstantInt) && (off == sizeof(Int) || off == 0)
-			ptr = reinterpret(Ptr{Cvoid}, convert(UInt, larg))
-			obj = Base.unsafe_pointer_to_objref(ptr)
+                if isa(cur, LLVM.ConstantExpr) && isa(value_type(cur), LLVM.PointerType) && LLVM.addrspace(value_type(cur)) == Derived
+                    larg, off = get_base_and_offset(cur; inst = first(instructions(position(prevbb))), inttoptr = true)
+                    if isa(larg, LLVM.ConstantInt) && (off == sizeof(Int) || off == 0)
+                        ptr = reinterpret(Ptr{Cvoid}, convert(UInt, larg))
+                        obj = Base.unsafe_pointer_to_objref(ptr)
                         if obj isa Memory && obj == typeof(obj).instance
                             return make_batched(ncur, prevbb)
                         end
-		    end
+                    end
                 end
-end
+            end
 
             legal, TT, byref = abs_typeof(cur, true)
 
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index fced4df..e0bab8e 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -2298,31 +2298,31 @@ function checkNoAssumeFalse(mod::LLVM.Module, shouldshow::Bool = false)
 end
 
 function rewrite_generic_memory!(mod::LLVM.Module)
-@static if VERSION < v"1.11-"
-else    
-    for f in functions(mod), bb in blocks(f)
-    iter = LLVM.API.LLVMGetFirstInstruction(bb)
-    while iter != C_NULL
-        inst = LLVM.Instruction(iter)
-        iter = LLVM.API.LLVMGetNextInstruction(iter)
-        if !isa(inst, LLVM.LoadInst)
-	   continue
-	end
-	
-	if isa(operands(inst)[1], LLVM.ConstantExpr)
+    return @static if VERSION < v"1.11-"
+    else
+        for f in functions(mod), bb in blocks(f)
+            iter = LLVM.API.LLVMGetFirstInstruction(bb)
+            while iter != C_NULL
+                inst = LLVM.Instruction(iter)
+                iter = LLVM.API.LLVMGetNextInstruction(iter)
+                if !isa(inst, LLVM.LoadInst)
+                    continue
+                end
+
+                if isa(operands(inst)[1], LLVM.ConstantExpr)
                     legal2, obj = absint(inst)
                     if legal2 && obj isa Memory && obj == typeof(obj).instance
-			b = LLVM.IRBuilder()
-			position!(b, inst)
-                       replace_uses!(inst, unsafe_to_llvm(b, obj))
-		       LLVM.API.LLVMInstructionEraseFromParent(inst)
-		       continue
-		    end
-		end
-    end
+                        b = LLVM.IRBuilder()
+                        position!(b, inst)
+                        replace_uses!(inst, unsafe_to_llvm(b, obj))
+                        LLVM.API.LLVMInstructionEraseFromParent(inst)
+                        continue
+                    end
+                end
+            end
+        end
     end
 end
-end
 
 function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
     # We need to run globalopt first. This is because remove dead args will otherwise

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Sep 11, 2025

Benchmark Results

main e589508... main / e589508...
basics/overhead 4.34 ± 0.01 ns 4.64 ± 0.01 ns 0.935 ± 0.003
time_to_load 1.24 ± 0.03 s 1.26 ± 0.015 s 0.985 ± 0.027

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/17654587380/artifacts/3989541443.

@codecov
Copy link
Copy Markdown

codecov Bot commented Sep 11, 2025

Codecov Report

❌ Patch coverage is 92.00000% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.00%. Comparing base (701f5f9) to head (e589508).
⚠️ Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
src/absint.jl 86.66% 2 Missing ⚠️
src/errors.jl 93.33% 1 Missing ⚠️
src/llvm/transforms.jl 94.73% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2561      +/-   ##
==========================================
+ Coverage   74.93%   75.00%   +0.06%     
==========================================
  Files          56       56              
  Lines       17445    17494      +49     
==========================================
+ Hits        13073    13121      +48     
- Misses       4372     4373       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Comment thread src/errors.jl
Comment on lines +532 to +548
ptr = reinterpret(Ptr{Cvoid}, convert(UInt, larg))
obj = Base.unsafe_pointer_to_objref(ptr)
if obj isa Memory && obj == typeof(obj).instance
return make_batched(ncur, prevbb)
end
end
end
if isa(cur, LLVM.ConstantExpr) && isa(value_type(cur), LLVM.PointerType) && LLVM.addrspace(value_type(cur)) == Derived
larg, off = get_base_and_offset(cur; inst=ncur, inttoptr=true)
if isa(larg, LLVM.ConstantInt) && off == sizeof(Int)
ptr = reinterpret(Ptr{Cvoid}, convert(UInt, larg))
obj = Base.unsafe_pointer_to_objref(ptr)
if obj isa Memory && obj == typeof(obj).instance
return make_batched(ncur, prevbb)
end
end
end
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to preserve the pointers? (side note, indentation looks off)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, the pointers don't need preservation, as they were baked into the IR and so are already preserved from Julia.

As for formatting, honestly my preference would be to set up a nightly auto-format PR like we have in Reactant.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am okay with that, but it would break the open PRs

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can wait to merge that until any specific big ones are in (or only apply it to certain files).

@wsmoses wsmoses merged commit 10b2caa into main Sep 11, 2025
30 of 36 checks passed
@wsmoses wsmoses deleted the mi2 branch September 11, 2025 21:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants