Skip to content

Commit 1077868

Browse files
committed
Fix creation of int tensors
1 parent b1e92e5 commit 1077868

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

src/torchjd/sparse/_structured_sparse_tensor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __new__(cls, physical: Tensor, strides: Tensor):
2828
# (which is bad!)
2929
assert not physical.requires_grad or not torch.is_grad_enabled()
3030

31-
pshape = torch.tensor(physical.shape)
31+
pshape = tensor(physical.shape, dtype=torch.int64)
3232
vshape = strides @ (pshape - 1) + 1
3333
return Tensor._make_wrapper_subclass(
3434
cls, tuple(vshape.tolist()), dtype=physical.dtype, device=physical.device
@@ -186,7 +186,7 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]:
186186

187187

188188
def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]:
189-
strides_time_pshape = strides * tensor(pshape)
189+
strides_time_pshape = strides * tensor(pshape, dtype=torch.int64)
190190
groups = {i: {i} for i, column in enumerate(strides.T)}
191191
group_ids = [i for i in range(len(strides.T))]
192192
for i1, i2 in itertools.combinations(range(strides.shape[1]), 2):
@@ -260,15 +260,15 @@ def get_full_source(source: list[int], destination: list[int], ndim: int) -> lis
260260
"""
261261

262262
idx = torch.full((ndim,), -1, dtype=torch.int64)
263-
idx[destination] = tensor(source)
263+
idx[destination] = tensor(source, dtype=torch.int64)
264264
source_set = set(source)
265-
idx[idx.eq(-1)] = tensor([i for i in range(ndim) if i not in source_set])
265+
idx[idx.eq(-1)] = tensor([i for i in range(ndim) if i not in source_set], dtype=torch.int64)
266266

267267
return idx.tolist()
268268

269269

270270
def fix_dim_of_size_1(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]:
271-
is_of_size_1 = torch.tensor([s == 1 for s in physical.shape], dtype=torch.bool)
271+
is_of_size_1 = tensor([s == 1 for s in physical.shape], dtype=torch.bool)
272272
return physical.squeeze(), strides[:, ~is_of_size_1]
273273

274274

0 commit comments

Comments
 (0)