We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 5595f5d commit 628e0b0Copy full SHA for 628e0b0
1 file changed
src/torchjd/autogram/_engine.py
@@ -319,11 +319,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
319
if has_non_batch_dim:
320
# There is one non-batched dimension, it is the first one
321
non_batch_dim_len = output.shape[0]
322
- jac_output_shape = [output.shape[0]] + list(output.shape)
323
-
324
- jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype)
325
- for i in range(non_batch_dim_len):
326
- jac_output[i, i, ...] = 1.0
+ identity_matrix = torch.eye(non_batch_dim_len, device=output.device, dtype=output.dtype)
+ ones = torch.ones_like(output[0])
+ jac_output = torch.einsum("ij, ... -> ij...", identity_matrix, ones)
327
328
_ = vmap(differentiation)(jac_output)
329
else:
0 commit comments