22from torch import Tensor , tensor
33from torch .ops import aten # type: ignore
44
5- from torchjd .sparse ._structured_sparse_tensor import (
6- StructuredSparseTensor ,
5+ from torchjd .sparse ._sparse_latticed_tensor import (
6+ SparseLatticedTensor ,
77 impl ,
88 to_most_efficient_tensor ,
9- to_structured_sparse_tensor ,
9+ to_sparse_latticed_tensor ,
1010)
1111
1212
13- def einsum (* args : tuple [StructuredSparseTensor , list [int ]], output : list [int ]) -> Tensor :
13+ def einsum (* args : tuple [SparseLatticedTensor , list [int ]], output : list [int ]) -> Tensor :
1414 raise NotImplementedError ()
1515
1616 # First part of the algorithm, determine how to cluster physical indices as well as the common
@@ -39,13 +39,13 @@ def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) -
3939 # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension.
4040 # For this reason, an index is decomposed into sub-indices that are then independently
4141 # clustered.
42- # So if an index i in args for some StructuredSparseTensor corresponds to a v_to_ps [j, k, l],
42+ # So if an index i in args for some SparseLatticedTensor corresponds to a v_to_ps [j, k, l],
4343 # We will consider three indices (i, 0), (i, 1) and (i, 2).
4444 # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then
4545 # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in
4646 # the resulting einsum).
4747 # Note that this is a problem if two virtual dimensions (from possibly different
48- # StructuredSparseTensors ) have the same size but not the same decomposition into physical
48+ # SparseLatticedTensors ) have the same size but not the same decomposition into physical
4949 # dimension sizes. For now lets leave the responsibility to care about that in the calling
5050 # functions, if we can factor code later on we will.
5151
@@ -71,7 +71,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None:
7171 physicals = list [Tensor ]()
7272 indices_to_n_pdims = dict [int , int ]()
7373 for t , indices in args :
74- assert isinstance (t , StructuredSparseTensor )
74+ assert isinstance (t , SparseLatticedTensor )
7575 physicals .append (t .physical )
7676 for pdims , index in zip (t .v_to_ps , indices ):
7777 if index in indices_to_n_pdims :
@@ -129,13 +129,13 @@ def unique_int(pair: tuple[int, int]) -> int:
129129
130130def prepare_for_elementwise_op (
131131 t1 : Tensor | int | float , t2 : Tensor | int | float
132- ) -> tuple [StructuredSparseTensor , StructuredSparseTensor ]:
132+ ) -> tuple [SparseLatticedTensor , SparseLatticedTensor ]:
133133 """
134134 Prepares two SSTs of the same shape from two args, one of those being a SST, and the other being
135135 a SST, Tensor, int or float.
136136 """
137137
138- assert isinstance (t1 , StructuredSparseTensor ) or isinstance (t2 , StructuredSparseTensor )
138+ assert isinstance (t1 , SparseLatticedTensor ) or isinstance (t2 , SparseLatticedTensor )
139139
140140 if isinstance (t1 , int ) or isinstance (t1 , float ):
141141 t1_ = tensor (t1 , device = t2 .device )
@@ -148,8 +148,8 @@ def prepare_for_elementwise_op(
148148 t2_ = t2
149149
150150 t1_ , t2_ = aten .broadcast_tensors .default ([t1_ , t2_ ])
151- t1_ = to_structured_sparse_tensor (t1_ )
152- t2_ = to_structured_sparse_tensor (t2_ )
151+ t1_ = to_sparse_latticed_tensor (t1_ )
152+ t2_ = to_sparse_latticed_tensor (t2_ )
153153
154154 return t1_ , t2_
155155
@@ -165,46 +165,46 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
@impl(aten.div.Tensor)
def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
    """Element-wise division t1 / t2, where at least one operand is a SparseLatticedTensor."""
    numerator, denominator = prepare_for_elementwise_op(t1, t2)
    # Invert only the stored (physical) values; the lattice structure is unchanged.
    # NOTE(review): positions not stored in t2's physical tensor stay non-stored in the
    # result rather than becoming inf/nan — confirm this matches the intended semantics.
    inverted = SparseLatticedTensor(1.0 / denominator.physical, denominator.strides)
    dims = list(range(numerator.ndim))
    return einsum((numerator, dims), (inverted, dims), output=dims)
171171
172172
@impl(aten.mul.Scalar)
def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor:
    """Multiply every stored value of ``t`` by ``scalar``, keeping the lattice structure."""
    # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to
    # check that
    assert isinstance(t, SparseLatticedTensor)

    # Scalar multiplication is linear, so scaling the physical values suffices.
    scaled_physical = aten.mul.Scalar(t.physical, scalar)
    return SparseLatticedTensor(scaled_physical, t.strides)
181181
182182
@impl(aten.add.Tensor)
def add_Tensor(
    t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0
) -> SparseLatticedTensor:
    """Compute ``t1 + alpha * t2`` where at least one operand is a SparseLatticedTensor."""
    lhs, rhs = prepare_for_elementwise_op(t1, t2)

    # Only the aligned case is supported: identical strides mean both operands store
    # values at the same lattice positions, so the op reduces to adding physicals.
    if not torch.equal(lhs.strides, rhs.strides):
        raise NotImplementedError()

    return SparseLatticedTensor(lhs.physical + rhs.physical * alpha, lhs.strides)
194194
195195
196196@impl (aten .bmm .default )
197197def bmm_default (mat1 : Tensor , mat2 : Tensor ) -> Tensor :
198- assert isinstance (mat1 , StructuredSparseTensor ) or isinstance (mat2 , StructuredSparseTensor )
198+ assert isinstance (mat1 , SparseLatticedTensor ) or isinstance (mat2 , SparseLatticedTensor )
199199 assert (
200200 mat1 .ndim == 3
201201 and mat2 .ndim == 3
202202 and mat1 .shape [0 ] == mat2 .shape [0 ]
203203 and mat1 .shape [2 ] == mat2 .shape [1 ]
204204 )
205205
206- mat1_ = to_structured_sparse_tensor (mat1 )
207- mat2_ = to_structured_sparse_tensor (mat2 )
206+ mat1_ = to_sparse_latticed_tensor (mat1 )
207+ mat2_ = to_sparse_latticed_tensor (mat2 )
208208
209209 # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes
210210 # decompositions. If not, can reshape to common decomposition?
@@ -213,32 +213,32 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
213213
@impl(aten.mm.default)
def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Matrix product of two 2-D tensors, at least one being a SparseLatticedTensor."""
    assert any(isinstance(m, SparseLatticedTensor) for m in (mat1, mat2))
    assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0]

    lhs = to_sparse_latticed_tensor(mat1)
    rhs = to_sparse_latticed_tensor(mat2)

    # Contract over the shared inner dimension (einsum index 1): "ij,jk->ik".
    return einsum((lhs, [0, 1]), (rhs, [1, 2]), output=[0, 2])
223223
224224
@impl(aten.mean.default)
def mean_default(t: SparseLatticedTensor) -> Tensor:
    """Mean over all elements of the virtual tensor.

    Sums the physical values and divides by the virtual element count.
    NOTE(review): exact only if non-stored positions are zero — confirm against
    SparseLatticedTensor's contract.
    """
    assert isinstance(t, SparseLatticedTensor)
    total = aten.sum.default(t.physical)
    return total / t.numel()
229229
230230
@impl(aten.sum.default)
def sum_default(t: SparseLatticedTensor) -> Tensor:
    """Total sum of the tensor's elements, computed on the physical storage."""
    assert isinstance(t, SparseLatticedTensor)
    # Summing the physical tensor suffices; non-stored positions are presumed zero
    # (TODO confirm against SparseLatticedTensor's contract).
    return aten.sum.default(t.physical)
235235
236236
237237@impl (aten .sum .dim_IntList )
238238def sum_dim_IntList (
239- t : StructuredSparseTensor , dim : list [int ], keepdim : bool = False , dtype = None
239+ t : SparseLatticedTensor , dim : list [int ], keepdim : bool = False , dtype = None
240240) -> Tensor :
241- assert isinstance (t , StructuredSparseTensor )
241+ assert isinstance (t , SparseLatticedTensor )
242242
243243 if dtype :
244244 raise NotImplementedError ()
0 commit comments