Skip to content

Commit 1f02813

Browse files
committed
Add assertion about strides.dtype in __new__
1 parent 1077868 commit 1f02813

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/torchjd/sparse/_structured_sparse_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def __new__(cls, physical: Tensor, strides: Tensor):
1818
# that the tensor we're wrapping is exactly a Tensor
1919
assert type(physical) is Tensor
2020

21+
assert strides.dtype == torch.int64
22+
2123
# Note [Passing requires_grad=true tensors to subclasses]
2224
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2325
# Calling _make_subclass directly in an autograd context is

0 commit comments

Comments
 (0)