@@ -28,7 +28,7 @@ def __new__(cls, physical: Tensor, strides: Tensor):
2828 # (which is bad!)
2929 assert not physical .requires_grad or not torch .is_grad_enabled ()
3030
31- pshape = torch . tensor (physical .shape )
31+ pshape = tensor (physical .shape , dtype = torch . int64 )
3232 vshape = strides @ (pshape - 1 ) + 1
3333 return Tensor ._make_wrapper_subclass (
3434 cls , tuple (vshape .tolist ()), dtype = physical .dtype , device = physical .device
@@ -186,7 +186,7 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]:
186186
187187
188188def get_groupings (pshape : list [int ], strides : Tensor ) -> list [list [int ]]:
189- strides_time_pshape = strides * tensor (pshape )
189+ strides_time_pshape = strides * tensor (pshape , dtype = torch . int64 )
190190 groups = {i : {i } for i , column in enumerate (strides .T )}
191191 group_ids = [i for i in range (len (strides .T ))]
192192 for i1 , i2 in itertools .combinations (range (strides .shape [1 ]), 2 ):
@@ -260,15 +260,15 @@ def get_full_source(source: list[int], destination: list[int], ndim: int) -> lis
260260 """
261261
262262 idx = torch .full ((ndim ,), - 1 , dtype = torch .int64 )
263- idx [destination ] = tensor (source )
263+ idx [destination ] = tensor (source , dtype = torch . int64 )
264264 source_set = set (source )
265- idx [idx .eq (- 1 )] = tensor ([i for i in range (ndim ) if i not in source_set ])
265+ idx [idx .eq (- 1 )] = tensor ([i for i in range (ndim ) if i not in source_set ], dtype = torch . int64 )
266266
267267 return idx .tolist ()
268268
269269
270270def fix_dim_of_size_1 (physical : Tensor , strides : Tensor ) -> tuple [Tensor , Tensor ]:
271- is_of_size_1 = torch . tensor ([s == 1 for s in physical .shape ], dtype = torch .bool )
271+ is_of_size_1 = tensor ([s == 1 for s in physical .shape ], dtype = torch .bool )
272272 return physical .squeeze (), strides [:, ~ is_of_size_1 ]
273273
274274
0 commit comments