@@ -28,8 +28,8 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor:
2828 c.T = [prod(t.shape[1:]), prod(t.shape[2:]), ..., t.shape[-1], 1]
2929 * c' is the same thing but after the reshape, i.e.
3030 c'.T = [prod(shape[1:]), prod(shape[2:]), ..., shape[-1], 1]
31- * S is the original matrix of strides (t.strides )
32- * S' is the matrix of strides after reshaping.
31+ * S is the original basis matrix (t.basis )
32+ * S' is the basis matrix after reshaping.
3333
3434 For u, v in Z^m and c in Z, say that u ≡ v (mod c) if u_i ≡ v_i (mod c) for all i.
3535 Note that c'.T S' ≡ S'[-1] (mod shape[-1])
@@ -46,12 +46,12 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor:
4646 if prod (shape ) != t .numel ():
4747 raise ValueError (f"shape '{ shape } ' is invalid for input of size { t .numel ()} " )
4848
49- S = t .strides
49+ S = t .basis
5050 vshape = list (t .shape )
5151 c = _reverse_cumulative_product (vshape )
5252 c_prime = _reverse_cumulative_product (shape )
53- new_strides = ((c @ S ).unsqueeze (0 ) // c_prime .unsqueeze (1 )) % tensor (shape ).unsqueeze (1 )
54- return to_most_efficient_tensor (t .physical , new_strides )
53+ new_basis = ((c @ S ).unsqueeze (0 ) // c_prime .unsqueeze (1 )) % tensor (shape ).unsqueeze (1 )
54+ return to_most_efficient_tensor (t .physical , new_basis )
5555
5656
5757def _reverse_cumulative_product (values : list [int ]) -> Tensor :
@@ -70,7 +70,7 @@ def infer_shape(shape: list[int], numel: int) -> list[int]:
7070
7171
7272def unsquash_pdim (
73- physical : Tensor , strides : Tensor , pdim : int , new_pdim_shape : list [int ]
73+ physical : Tensor , basis : Tensor , pdim : int , new_pdim_shape : list [int ]
7474) -> tuple [Tensor , Tensor ]:
7575 """
7676 EXAMPLE:
@@ -80,7 +80,7 @@ def unsquash_pdim(
8080 [7, 8, 9, 10, 11, 12],
8181 [13, 14, 15, 16, 17, 18],
8282 ]
83- strides = [
83+ basis = [
8484 [1, 1],
8585 [0, 2],
8686 ]
@@ -99,7 +99,7 @@ def unsquash_pdim(
9999 [16, 17, 18],
100100 ]]
101101
102- new_strides = [
102+ new_basis = [
103103 [1, 3, 1],
104104 [0, 6, 2]
105105 """
@@ -110,18 +110,18 @@ def unsquash_pdim(
110110 new_shape = old_shape [:pdim ] + new_pdim_shape + old_shape [pdim + 1 :]
111111 new_physical = physical .reshape (new_shape )
112112
113- stride_multipliers = tensor ([prod (new_pdim_shape [i + 1 :]) for i in range (len (new_pdim_shape ))])
113+ multipliers = tensor ([prod (new_pdim_shape [i + 1 :]) for i in range (len (new_pdim_shape ))])
114114
115- new_strides = torch .concat (
115+ new_basis = torch .concat (
116116 [
117- strides [:, :pdim ],
118- torch .outer (strides [:, pdim ], stride_multipliers ),
119- strides [:, pdim + 1 :],
117+ basis [:, :pdim ],
118+ torch .outer (basis [:, pdim ], multipliers ),
119+ basis [:, pdim + 1 :],
120120 ],
121121 dim = 1 ,
122122 )
123123
124- return new_physical , new_strides
124+ return new_physical , new_basis
125125
126126
127127@impl (aten ._unsafe_view .default )
@@ -139,10 +139,10 @@ def unsqueeze_default(t: SparseLatticedTensor, dim: int) -> SparseLatticedTensor
139139 if dim < 0 :
140140 dim = t .ndim + dim + 1
141141
142- new_strides = torch .concatenate (
143- [t .strides [:dim ], torch .zeros (1 , t .strides .shape [1 ], dtype = torch .int64 ), t .strides [dim :]]
142+ new_basis = torch .concatenate (
143+ [t .basis [:dim ], torch .zeros (1 , t .basis .shape [1 ], dtype = torch .int64 ), t .basis [dim :]]
144144 )
145- return SparseLatticedTensor (t .physical , new_strides )
145+ return SparseLatticedTensor (t .physical , new_basis )
146146
147147
148148@impl (aten .squeeze .dims )
@@ -157,14 +157,14 @@ def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tenso
157157 excluded = set (dims )
158158
159159 is_row_kept = [i not in excluded for i in range (t .ndim )]
160- new_strides = t .strides [is_row_kept ]
161- return to_most_efficient_tensor (t .physical , new_strides )
160+ new_basis = t .basis [is_row_kept ]
161+ return to_most_efficient_tensor (t .physical , new_basis )
162162
163163
164164@impl (aten .permute .default )
165165def permute_default (t : SparseLatticedTensor , dims : list [int ]) -> SparseLatticedTensor :
166- new_strides = t .strides [torch .tensor (dims )]
167- return SparseLatticedTensor (t .physical , new_strides )
166+ new_basis = t .basis [torch .tensor (dims )]
167+ return SparseLatticedTensor (t .physical , new_basis )
168168
169169
170170@impl (aten .cat .default )
@@ -175,11 +175,11 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
175175
176176 tensors_ = [cast (SparseLatticedTensor , t ) for t in tensors ]
177177 ref_tensor = tensors_ [0 ]
178- ref_strides = ref_tensor .strides
179- if any (not torch .equal (t .strides , ref_strides ) for t in tensors_ [1 :]):
178+ ref_basis = ref_tensor .basis
179+ if any (not torch .equal (t .basis , ref_basis ) for t in tensors_ [1 :]):
180180 raise NotImplementedError (
181181 "Override for aten.cat.default does not support SSTs that do not all have the same "
182- f"strides . Found the following tensors:\n { [t .debug_info () for t in tensors_ ]} and the "
182+ f"basis . Found the following tensors:\n { [t .debug_info () for t in tensors_ ]} and the "
183183 f"following dim: { dim } ."
184184 )
185185
@@ -189,8 +189,8 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
189189
190190 ref_virtual_dim_size = ref_tensor .shape [dim ]
191191 indices = torch .argwhere (
192- torch .eq (ref_strides [dim ] * tensor (ref_tensor .physical .shape ), ref_virtual_dim_size )
193- & torch .eq (ref_strides .sum (dim = 0 ) * tensor (ref_tensor .physical .shape ), ref_virtual_dim_size )
192+ torch .eq (ref_basis [dim ] * tensor (ref_tensor .physical .shape ), ref_virtual_dim_size )
193+ & torch .eq (ref_basis .sum (dim = 0 ) * tensor (ref_tensor .physical .shape ), ref_virtual_dim_size )
194194 )
195195 assert len (indices ) <= 1
196196
@@ -200,18 +200,18 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
200200
201201 pdim = ref_tensor .physical .ndim
202202 physicals = [t .physical .unsqueeze (- 1 ) for t in tensors_ ]
203- new_stride_column = torch .zeros (ref_tensor .ndim , 1 , dtype = torch .int64 )
204- new_stride_column [dim , 0 ] = ref_virtual_dim_size
205- new_strides = torch .concatenate ([ref_tensor .strides , new_stride_column ], dim = 1 )
203+ new_basis_vector = torch .zeros (ref_tensor .ndim , 1 , dtype = torch .int64 )
204+ new_basis_vector [dim , 0 ] = ref_virtual_dim_size
205+ new_basis = torch .concatenate ([ref_tensor .basis , new_basis_vector ], dim = 1 )
206206 else :
207207 # Such a physical dimension already exists. Note that an alternative implementation would be
208208 # to simply always add the physical dimension, and squash it if it ends up being not needed.
209209 physicals = [t .physical for t in tensors_ ]
210210 pdim = cast (int , indices [0 , 0 ].item ())
211- new_strides = ref_tensor .strides
211+ new_basis = ref_tensor .basis
212212
213213 new_physical = aten .cat .default (physicals , dim = pdim )
214- return SparseLatticedTensor (new_physical , new_strides )
214+ return SparseLatticedTensor (new_physical , new_basis )
215215
216216
217217@impl (aten .expand .default )
@@ -227,26 +227,26 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedT
227227
228228 # Try to expand each dimension to its new size
229229 new_physical = t .physical
230- new_strides = t .strides
231- for d , (vstride , orig_size , new_size ) in enumerate (zip (t .strides , t .shape , sizes , strict = True )):
232- if vstride .sum () > 0 and orig_size != new_size and new_size != - 1 :
230+ new_basis = t .basis
231+ for d , (v , orig_size , new_size ) in enumerate (zip (t .basis , t .shape , sizes , strict = True )):
232+ if v .sum () > 0 and orig_size != new_size and new_size != - 1 :
233233 raise ValueError (
234234 f"Cannot expand dim { d } of size != 1. Found size { orig_size } and target size "
235235 f"{ new_size } ."
236236 )
237237
238- if vstride .sum () == 0 and new_size != 1 and new_size != - 1 :
238+ if v .sum () == 0 and new_size != 1 and new_size != - 1 :
239239 # Add a dimension of size new_size at the end of the physical tensor.
240240 new_physical_shape = list (new_physical .shape ) + [new_size ]
241241 new_physical = new_physical .unsqueeze (- 1 ).expand (new_physical_shape )
242242
243- # Make this new physical dimension have a stride of 1 at virtual dimension d and 0 at
244- # every other virtual dimension
245- new_stride_column = torch .zeros (t .ndim , 1 , dtype = torch .int64 )
246- new_stride_column [d , 0 ] = 1
247- new_strides = torch .cat ([new_strides , new_stride_column ], dim = 1 )
243+ # Make the basis vector of this new physical dimension be 1 at virtual dimension d and 0
244+ # at every other virtual dimension
245+ new_basis_vector = torch .zeros (t .ndim , 1 , dtype = torch .int64 )
246+ new_basis_vector [d , 0 ] = 1
247+ new_basis = torch .cat ([new_basis , new_basis_vector ], dim = 1 )
248248
249- return SparseLatticedTensor (new_physical , new_strides )
249+ return SparseLatticedTensor (new_physical , new_basis )
250250
251251
252252@impl (aten .broadcast_tensors .default )
@@ -279,7 +279,7 @@ def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]:
279279@impl (aten .transpose .int )
280280def transpose_int (t : SparseLatticedTensor , dim0 : int , dim1 : int ) -> SparseLatticedTensor :
281281 assert isinstance (t , SparseLatticedTensor )
282- return SparseLatticedTensor (t .physical , _swap_rows (t .strides , dim0 , dim1 ))
282+ return SparseLatticedTensor (t .physical , _swap_rows (t .basis , dim0 , dim1 ))
283283
284284
285285def _swap_rows (matrix : Tensor , c0 : int , c1 : int ) -> Tensor :
0 commit comments