Skip to content

Commit b1e92e5

Browse files
committed
Fix dtype in fix_dim_of_size_1
1 parent 9d4f2c8 commit b1e92e5

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/torchjd/sparse/_structured_sparse_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def get_full_source(source: list[int], destination: list[int], ndim: int) -> lis
268268

269269

270270
def fix_dim_of_size_1(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]:
271-
is_of_size_1 = torch.tensor([s == 1 for s in physical.shape])
271+
is_of_size_1 = torch.tensor([s == 1 for s in physical.shape], dtype=torch.bool)
272272
return physical.squeeze(), strides[:, ~is_of_size_1]
273273

274274

0 commit comments

Comments
 (0)