Register reverse mode AD for memref.alloca_scope #2801

Closed
xys-syx wants to merge 3 commits into EnzymeAD:main from xys-syx:alloca_scope


xys-syx (Collaborator) commented Apr 30, 2026

Attempts to fix #2791.

xys-syx marked this pull request as ready for review May 7, 2026 03:49
xys-syx requested review from vimarsh6739 and wsmoses May 7, 2026 03:49
vimarsh6739 (Member) commented May 7, 2026

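I don't think this will work if the alloca_scope has a memref.alloca inside it. Caching the memref is unsafe across the two alloca_scope ops that are emitted (one for the forward pass, one for the reverse pass), because anything allocated inside the first scope is freed when that scope exits.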

Something like this will fail:

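func.func @foo(%x : f64) -> f64 {
  %out = memref.alloca_scope -> (f64) {
    // Stack buffer whose lifetime ends when the scope exits.
    %buf = memref.alloca() : memref<f64>
    memref.store %x, %buf[] : memref<f64>
    %y = memref.load %buf[] : memref<f64>
    memref.alloca_scope.return %y : f64
  }
  return %out : f64
}

func.func @dfoo(%x : f64, %dout : f64) -> f64 {
  // Reverse mode: %x is active, and the primal return value is not needed.
  %dx = enzyme.autodiff @foo(%x, %dout) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_activenoneed>] } : (f64, f64) -> f64
  return %dx : f64
}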

The primal is basically:

out = alloca_scope {
    buf[] = x
    y = buf[]
    return y
}

Reverse mode should give us:

dy += dout
d(buf) += dy   // load
dx += d(buf)   // store
d(buf) = 0     // alloca
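
As a sanity check on the adjoint updates above, here is a minimal Python sketch that runs the same four steps on a one-element list standing in for the buffer. The names `foo` and `foo_reverse` are hypothetical and only mirror the example; this is a toy model, not how Enzyme emits the reverse pass.

```python
def foo(x):
    """Primal: store x into a one-element buffer, load it back, return it."""
    buf = [0.0]   # stands in for memref.alloca
    buf[0] = x    # memref.store
    y = buf[0]    # memref.load
    return y


def foo_reverse(x, dout):
    """Reverse sweep: apply the adjoint updates in reverse program order."""
    dbuf = [0.0]  # shadow of buf, zero-initialized
    dx = 0.0
    dy = 0.0

    dy += dout        # terminator: the returned value receives dout
    dbuf[0] += dy     # load: the adjoint flows into the shadow buffer
    dx += dbuf[0]     # store: the adjoint flows out of the shadow into dx
    dbuf[0] = 0.0     # the shadow is cleared afterwards
    return dx, dbuf[0]


if __name__ == "__main__":
    print(foo(3.0))
    print(foo_reverse(3.0, 2.5))
```

Since `foo` is the identity on `x`, the gradient with respect to `x` should equal `dout`, and the shadow buffer should end up zeroed.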

We want this to lower to something like the following. We still need to discuss what the caching should look like, since I don't think we can just push the memref to the cache here.

// forward
memref.alloca_scope {
    %buf = memref.alloca()
    %dbuf = memref.alloca()
    zero %dbuf

    enzyme.push %cache_store, %dbuf
    memref.store %x, %buf

    enzyme.push %cache_load, %dbuf
    %y = memref.load %buf

    memref.alloca_scope.return %y
}

// reverse
memref.alloca_scope {
    // terminator
    dy += dout
    // load
    %dbuf1 = enzyme.pop %cache_load
    %old = memref.load %dbuf1
    memref.store (%old + dy), %dbuf1
    // store
    %dbuf2 = enzyme.pop %cache_store
    %old2 = memref.load %dbuf2
    dx += %old2

    memref.store 0, %dbuf2
}

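The issue is that the cache push/pop will not be valid across the two alloca_scopes here: %dbuf is allocated inside the forward scope, so the memref popped in the reverse scope refers to memory that has already been released.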
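
To make the lifetime problem concrete, here is a toy Python model of it. Everything in it is an assumption made for illustration: `Scope` stands in for memref.alloca_scope, `Buffer` for a stack-allocated memref, and the `alive` flag is an artificial way to surface what would be a silent use-after-free in real code. Two plain lists stand in for the enzyme caches.

```python
class Buffer:
    """Toy stack buffer that refuses access once its scope has exited."""

    def __init__(self):
        self.value = 0.0
        self.alive = True

    def _check(self):
        if not self.alive:
            raise RuntimeError("buffer used after its alloca_scope exited")

    def load(self):
        self._check()
        return self.value

    def store(self, v):
        self._check()
        self.value = v


class Scope:
    """Toy alloca_scope: every buffer it allocated dies when the scope exits."""

    def __init__(self):
        self.buffers = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        for b in self.buffers:
            b.alive = False
        return False  # do not swallow exceptions

    def alloca(self):
        b = Buffer()
        self.buffers.append(b)
        return b


def forward_then_reverse(x, dout):
    cache_store, cache_load = [], []

    with Scope() as fwd:                 # forward alloca_scope
        buf = fwd.alloca()
        dbuf = fwd.alloca()              # shadow buffer, already zero
        cache_store.append(dbuf)         # push the memref itself
        buf.store(x)
        cache_load.append(dbuf)
        y = buf.load()

    with Scope():                        # reverse alloca_scope
        dy = dout
        dbuf1 = cache_load.pop()
        dbuf1.store(dbuf1.load() + dy)   # fails: dbuf died with the forward scope
        dbuf2 = cache_store.pop()
        dx = dbuf2.load()
        dbuf2.store(0.0)
    return dx


if __name__ == "__main__":
    try:
        forward_then_reverse(3.0, 1.0)
    except RuntimeError as err:
        print("error:", err)
```

In this model the first access to the popped shadow buffer raises, which is the same hazard as pushing the memref to the cache in the forward scope and popping it in the reverse scope.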

vimarsh6739 closed this May 7, 2026

vimarsh6739 (Member) commented

#2814


Development

Successfully merging this pull request may close these issues.

memref.alloca_scope and its terminator memref.alloca_scope.return does not register reverse mode AD
