Skip to content

Commit f54c327

Browse files
committed
Fix to_most_efficient_tensor
1 parent 1f02813 commit f54c327

1 file changed

Lines changed: 2 additions & 18 deletions

File tree

src/torchjd/sparse/_structured_sparse_tensor.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -216,24 +216,8 @@ def to_most_efficient_tensor(physical: Tensor, strides: Tensor) -> Tensor:
216216
physical, strides = fix_ungrouped_dims(physical, strides)
217217

218218
if (strides.sum(dim=0) == 1).all():
219-
# All physical dimensions make you move by 1 in exactly 1 virtual dimension.
220-
# Also, because all physical dimensions have been maximally grouped, we cannot have two
221-
# physical dimensions that make you move in the same virtual dimension.
222-
# So strides is an identity matrix with potentially some extra rows of zeros, and
223-
# potentially shuffled columns.
224-
225-
# The first step is to unsqueeze the physical tensor for each extra row of zeros in the
226-
# strides.
227-
zero_row_mask = strides.sum(dim=1) == 0
228-
number_of_zero_rows = zero_row_mask.sum()
229-
for _ in number_of_zero_rows:
230-
physical = physical.unsqueeze(-1)
231-
232-
# The second step is to re-order the physical dimensions so that the corresponding
233-
# strides matrix would be an identity.
234-
source = arange(strides.shape[0])
235-
destination = strides[zero_row_mask] @ source
236-
return physical.movedim(list(source), list(destination))
219+
# TODO: this can be done more efficiently (without even creating the SST)
220+
return StructuredSparseTensor(physical, strides).to_dense()
237221
else:
238222
return StructuredSparseTensor(physical, strides)
239223

0 commit comments

Comments
 (0)