@@ -65,7 +65,7 @@ def hook_module(self, module: nn.Module) -> None:
             self._gramian_accumulator,
             self._has_batch_dim,
         )
-        self._handles.append(module.register_forward_hook(hook))
+        self._handles.append(module.register_forward_hook(hook, with_kwargs=True))

     @staticmethod
     def remove_hooks(handles: list[TorchRemovableHandle]) -> None:
@@ -101,7 +101,13 @@ def __init__(
         self.gramian_accumulator = gramian_accumulator
         self.has_batch_dim = has_batch_dim

-    def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) -> PyTree:
+    def __call__(
+        self,
+        module: nn.Module,
+        args: tuple[PyTree, ...],
+        kwargs: dict[str, PyTree],
+        outputs: PyTree,
+    ) -> PyTree:
         if self.gramian_accumulation_phase:
             return outputs
@@ -131,9 +137,10 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree)

         vjp: VJP
         if self.has_batch_dim:
-            rg_outputs_in_dims = (0,) * len(rg_outputs)
-            args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
-            in_dims = (rg_outputs_in_dims, args_in_dims)
+            rg_output_in_dims = (0,) * len(rg_outputs)
+            arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
+            kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs)
+            in_dims = (rg_output_in_dims, arg_in_dims, kwargs_in_dims)
             vjp = FunctionalVJP(module, in_dims)
         else:
             vjp = AutogradVJP(module, rg_outputs)
@@ -142,6 +149,7 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree)
             self.gramian_accumulation_phase,
             vjp,
             args,
+            kwargs,
             self.gramian_accumulator,
             module,
             *rg_outputs,
@@ -169,14 +177,15 @@ def forward(
         gramian_accumulation_phase: BoolRef,
         vjp: VJP,
         args: tuple[PyTree, ...],
+        kwargs: dict[str, PyTree],
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *rg_tensors: Tensor,
     ) -> tuple[Tensor, ...]:
         return tuple(t.detach() for t in rg_tensors)

     # For Python version > 3.10, the type of `inputs` should become
-    # tuple[BoolRef, VJP, tuple[PyTree, ...], GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
+    # tuple[BoolRef, VJP, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
     @staticmethod
     def setup_context(
         ctx,
@@ -186,25 +195,27 @@ def setup_context(
         ctx.gramian_accumulation_phase = inputs[0]
         ctx.vjp = inputs[1]
         ctx.args = inputs[2]
-        ctx.gramian_accumulator = inputs[3]
-        ctx.module = inputs[4]
+        ctx.kwargs = inputs[3]
+        ctx.gramian_accumulator = inputs[4]
+        ctx.module = inputs[5]

     @staticmethod
     def backward(ctx, *grad_outputs: Tensor) -> tuple:
-        # Return type for python > 3.10: # tuple[None, None, None, None, None, *tuple[Tensor, ...]]
+        # For python > 3.10: -> tuple[None, None, None, None, None, None, *tuple[Tensor, ...]]

         if not ctx.gramian_accumulation_phase:
-            return None, None, None, None, None, *grad_outputs
+            return None, None, None, None, None, None, *grad_outputs

         AccumulateJacobian.apply(
             ctx.vjp,
             ctx.args,
+            ctx.kwargs,
             ctx.gramian_accumulator,
             ctx.module,
             *grad_outputs,
         )

-        return None, None, None, None, None, *grad_outputs
+        return None, None, None, None, None, None, *grad_outputs


 class AccumulateJacobian(torch.autograd.Function):
@@ -213,29 +224,31 @@ class AccumulateJacobian(torch.autograd.Function):
     def forward(
         vjp: VJP,
         args: tuple[PyTree, ...],
+        kwargs: dict[str, PyTree],
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *grad_outputs: Tensor,
     ) -> None:
         # There is no non-batched dimension
-        generalized_jacobians = vjp(grad_outputs, args)
+        generalized_jacobians = vjp(grad_outputs, args, kwargs)
         path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
         gramian_accumulator.accumulate_path_jacobians(path_jacobians)

     @staticmethod
     def vmap(
         _,
-        in_dims: tuple,  # tuple[None, tuple[PyTree, ...], None, None, *tuple[int | None, ...]]
+        in_dims: tuple,  # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, None, *tuple[int | None, ...]]
         vjp: VJP,
         args: tuple[PyTree, ...],
+        kwargs: dict[str, PyTree],
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *jac_outputs: Tensor,
     ) -> tuple[None, None]:
         # There is a non-batched dimension
         # We do not vmap over the args for the non-batched dimension
-        in_dims = (in_dims[4:], tree_map(lambda _: None, args))
-        generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args)
+        in_dims = (in_dims[5:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs))
+        generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs)
         path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
         gramian_accumulator.accumulate_path_jacobians(path_jacobians)
         return None, None
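
Not part of the diff itself, but a reminder of the PyTorch behavior this change relies on: registering a forward hook with `with_kwargs=True` makes PyTorch invoke it as `hook(module, args, kwargs, output)` instead of `hook(module, args, output)`, which is why a `kwargs` parameter is threaded through every signature above. A minimal, self-contained sketch (toy module and hook names, unrelated to this repository):

import torch
from torch import nn

def hook(module: nn.Module, args: tuple, kwargs: dict, output: torch.Tensor) -> torch.Tensor:
    # With with_kwargs=True, keyword arguments passed to the module's forward
    # are handed to the hook as a dict instead of being dropped.
    print(f"{type(module).__name__} called with kwargs={list(kwargs)}")
    return output

layer = nn.Linear(3, 2)
handle = layer.register_forward_hook(hook, with_kwargs=True)
layer(input=torch.randn(4, 3))  # prints: Linear called with kwargs=['input']
handle.remove()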