Skip to content

Commit 72b382a

Browse files
committed
Move fix_zero_stride_columns to _coalesce.py
1 parent ce358d0 commit 72b382a

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

src/torchjd/sparse/_coalesce.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
from torch import Tensor
3+
4+
5+
def fix_zero_stride_columns(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]:
6+
"""
7+
Remove columns of strides that are all 0 and sum the corresponding elements in the physical
8+
tensor.
9+
"""
10+
11+
are_columns_zero = (strides == 0).all(dim=0)
12+
13+
if not are_columns_zero.any():
14+
return physical, strides
15+
16+
zero_column_indices = torch.arange(len(are_columns_zero))[are_columns_zero].tolist()
17+
physical = physical.sum(dim=zero_column_indices)
18+
strides = strides[:, ~are_columns_zero]
19+
return physical, strides

src/torchjd/sparse/_structured_sparse_tensor.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -276,20 +276,3 @@ def make_sst(physical: Tensor, strides: Tensor) -> StructuredSparseTensor:
276276
physical, strides = fix_dim_of_size_1(physical, strides)
277277
physical, strides = fix_ungrouped_dims(physical, strides)
278278
return StructuredSparseTensor(physical, strides)
279-
280-
281-
def fix_zero_stride_columns(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]:
282-
"""
283-
Remove columns of strides that are all 0 and sum the corresponding elements in the physical
284-
tensor.
285-
"""
286-
287-
are_columns_zero = (strides == 0).all(dim=0)
288-
289-
if not are_columns_zero.any():
290-
return physical, strides
291-
292-
zero_column_indices = torch.arange(len(are_columns_zero))[are_columns_zero].tolist()
293-
physical = physical.sum(dim=zero_column_indices)
294-
strides = strides[:, ~are_columns_zero]
295-
return physical, strides

tests/unit/sparse/test_structured_sparse_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
_POINTWISE_FUNCTIONS,
1212
)
1313
from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim
14+
from torchjd.sparse._coalesce import fix_zero_stride_columns
1415
from torchjd.sparse._structured_sparse_tensor import (
1516
StructuredSparseTensor,
1617
fix_ungrouped_dims,
17-
fix_zero_stride_columns,
1818
get_full_source,
1919
get_groupings,
2020
)

0 commit comments

Comments
 (0)