@@ -216,24 +216,8 @@ def to_most_efficient_tensor(physical: Tensor, strides: Tensor) -> Tensor:
216216 physical , strides = fix_ungrouped_dims (physical , strides )
217217
218218 if (strides .sum (dim = 0 ) == 1 ).all ():
219- # All physical dimensions make you move by 1 in exactly 1 virtual dimension.
220- # Also, because all physical dimensions have been maximally grouped, we cannot have two
221- # physical dimensions that make you move in the same virtual dimension.
222- # So strides is an identity matrix with potentially some extra rows of zeros, and
223- # potentially shuffled columns.
224-
225- # The first step is to unsqueeze the physical tensor for each extra row of zeros in the
226- # strides.
227- zero_row_mask = strides .sum (dim = 1 ) == 0
228- number_of_zero_rows = zero_row_mask .sum ()
229- for _ in number_of_zero_rows :
230- physical = physical .unsqueeze (- 1 )
231-
232- # The second step is to re-order the physical dimensions so that the corresponding
233- # strides matrix would be an identity.
234- source = arange (strides .shape [0 ])
235- destination = strides [zero_row_mask ] @ source
236- return physical .movedim (list (source ), list (destination ))
219+ # TODO: this can be done more efficiently (without even creating the SST)
220+ return StructuredSparseTensor (physical , strides ).to_dense ()
237221 else :
238222 return StructuredSparseTensor (physical , strides )
239223
0 commit comments