Skip to content

Commit ce358d0

Browse files
committed
Fix formatting
1 parent e885361 commit ce358d0

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/torchjd/sparse/_structured_sparse_tensor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,14 @@ def make_sst(physical: Tensor, strides: Tensor) -> StructuredSparseTensor:
279279

280280

281281
def fix_zero_stride_columns(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]:
282-
"""Remove columns of strides that are all 0 and sum the corresponding elements in the physical tensor."""
282+
"""
283+
Remove columns of strides that are all 0 and sum the corresponding elements in the physical
284+
tensor.
285+
"""
286+
283287
are_columns_zero = (strides == 0).all(dim=0)
284288

285-
if not (are_columns_zero).any():
289+
if not are_columns_zero.any():
286290
return physical, strides
287291

288292
zero_column_indices = torch.arange(len(are_columns_zero))[are_columns_zero].tolist()

0 commit comments

Comments
 (0)