Skip to content

Commit 131fbb4

Browse files
committed
Rename SST to SparseLatticedTensor
1 parent 35522f7 commit 131fbb4

7 files changed

Lines changed: 108 additions & 108 deletions

File tree

src/torchjd/sparse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# Need to import this to execute the code inside and thus to override the functions
22
from . import _aten_function_overrides
3-
from ._structured_sparse_tensor import StructuredSparseTensor, make_sst
3+
from ._sparse_latticed_tensor import SparseLatticedTensor, make_sst
Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,36 @@
11
from torch import Tensor
22
from torch.ops import aten # type: ignore
33

4-
from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl
4+
from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor, impl
55

66

77
@impl(aten.threshold_backward.default)
88
def threshold_backward_default(
9-
grad_output: StructuredSparseTensor, self: Tensor, threshold
10-
) -> StructuredSparseTensor:
9+
grad_output: SparseLatticedTensor, self: Tensor, threshold
10+
) -> SparseLatticedTensor:
1111
new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold)
1212

13-
return StructuredSparseTensor(new_physical, grad_output.strides)
13+
return SparseLatticedTensor(new_physical, grad_output.strides)
1414

1515

1616
@impl(aten.hardtanh_backward.default)
1717
def hardtanh_backward_default(
18-
grad_output: StructuredSparseTensor,
18+
grad_output: SparseLatticedTensor,
1919
self: Tensor,
2020
min_val: Tensor | int | float,
2121
max_val: Tensor | int | float,
22-
) -> StructuredSparseTensor:
23-
if isinstance(self, StructuredSparseTensor):
22+
) -> SparseLatticedTensor:
23+
if isinstance(self, SparseLatticedTensor):
2424
raise NotImplementedError()
2525

2626
new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val)
27-
return StructuredSparseTensor(new_physical, grad_output.strides)
27+
return SparseLatticedTensor(new_physical, grad_output.strides)
2828

2929

3030
@impl(aten.hardswish_backward.default)
31-
def hardswish_backward_default(grad_output: StructuredSparseTensor, self: Tensor):
32-
if isinstance(self, StructuredSparseTensor):
31+
def hardswish_backward_default(grad_output: SparseLatticedTensor, self: Tensor):
32+
if isinstance(self, SparseLatticedTensor):
3333
raise NotImplementedError()
3434

3535
new_physical = aten.hardswish_backward.default(grad_output.physical, self)
36-
return StructuredSparseTensor(new_physical, grad_output.strides)
36+
return SparseLatticedTensor(new_physical, grad_output.strides)

src/torchjd/sparse/_aten_function_overrides/einsum.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
from torch import Tensor, tensor
33
from torch.ops import aten # type: ignore
44

5-
from torchjd.sparse._structured_sparse_tensor import (
6-
StructuredSparseTensor,
5+
from torchjd.sparse._sparse_latticed_tensor import (
6+
SparseLatticedTensor,
77
impl,
88
to_most_efficient_tensor,
9-
to_structured_sparse_tensor,
9+
to_sparse_latticed_tensor,
1010
)
1111

1212

13-
def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) -> Tensor:
13+
def einsum(*args: tuple[SparseLatticedTensor, list[int]], output: list[int]) -> Tensor:
1414
raise NotImplementedError()
1515

1616
# First part of the algorithm, determine how to cluster physical indices as well as the common
@@ -39,13 +39,13 @@ def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) -
3939
# [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension.
4040
# For this reason, an index is decomposed into sub-indices that are then independently
4141
# clustered.
42-
# So if an index i in args for some StructuredSparseTensor corresponds to a v_to_ps [j, k, l],
42+
# So if an index i in args for some SparseLatticedTensor corresponds to a v_to_ps [j, k, l],
4343
# We will consider three indices (i, 0), (i, 1) and (i, 2).
4444
# If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then
4545
# (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in
4646
# the resulting einsum).
4747
# Note that this is a problem if two virtual dimensions (from possibly different
48-
# StructuredSparseTensors) have the same size but not the same decomposition into physical
48+
# SparseLatticedTensors) have the same size but not the same decomposition into physical
4949
# dimension sizes. For now lets leave the responsibility to care about that in the calling
5050
# functions, if we can factor code later on we will.
5151

@@ -71,7 +71,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None:
7171
physicals = list[Tensor]()
7272
indices_to_n_pdims = dict[int, int]()
7373
for t, indices in args:
74-
assert isinstance(t, StructuredSparseTensor)
74+
assert isinstance(t, SparseLatticedTensor)
7575
physicals.append(t.physical)
7676
for pdims, index in zip(t.v_to_ps, indices):
7777
if index in indices_to_n_pdims:
@@ -129,13 +129,13 @@ def unique_int(pair: tuple[int, int]) -> int:
129129

130130
def prepare_for_elementwise_op(
131131
t1: Tensor | int | float, t2: Tensor | int | float
132-
) -> tuple[StructuredSparseTensor, StructuredSparseTensor]:
132+
) -> tuple[SparseLatticedTensor, SparseLatticedTensor]:
133133
"""
134134
Prepares two SSTs of the same shape from two args, one of those being a SST, and the other being
135135
a SST, Tensor, int or float.
136136
"""
137137

138-
assert isinstance(t1, StructuredSparseTensor) or isinstance(t2, StructuredSparseTensor)
138+
assert isinstance(t1, SparseLatticedTensor) or isinstance(t2, SparseLatticedTensor)
139139

140140
if isinstance(t1, int) or isinstance(t1, float):
141141
t1_ = tensor(t1, device=t2.device)
@@ -148,8 +148,8 @@ def prepare_for_elementwise_op(
148148
t2_ = t2
149149

150150
t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_])
151-
t1_ = to_structured_sparse_tensor(t1_)
152-
t2_ = to_structured_sparse_tensor(t2_)
151+
t1_ = to_sparse_latticed_tensor(t1_)
152+
t2_ = to_sparse_latticed_tensor(t2_)
153153

154154
return t1_, t2_
155155

@@ -165,46 +165,46 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
165165
@impl(aten.div.Tensor)
166166
def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
167167
t1_, t2_ = prepare_for_elementwise_op(t1, t2)
168-
t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.strides)
168+
t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.strides)
169169
all_dims = list(range(t1_.ndim))
170170
return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims)
171171

172172

173173
@impl(aten.mul.Scalar)
174-
def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor:
174+
def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor:
175175
# TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check
176176
# that
177177

178-
assert isinstance(t, StructuredSparseTensor)
178+
assert isinstance(t, SparseLatticedTensor)
179179
new_physical = aten.mul.Scalar(t.physical, scalar)
180-
return StructuredSparseTensor(new_physical, t.strides)
180+
return SparseLatticedTensor(new_physical, t.strides)
181181

182182

183183
@impl(aten.add.Tensor)
184184
def add_Tensor(
185185
t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0
186-
) -> StructuredSparseTensor:
186+
) -> SparseLatticedTensor:
187187
t1_, t2_ = prepare_for_elementwise_op(t1, t2)
188188

189189
if torch.equal(t1_.strides, t2_.strides):
190190
new_physical = t1_.physical + t2_.physical * alpha
191-
return StructuredSparseTensor(new_physical, t1_.strides)
191+
return SparseLatticedTensor(new_physical, t1_.strides)
192192
else:
193193
raise NotImplementedError()
194194

195195

196196
@impl(aten.bmm.default)
197197
def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
198-
assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
198+
assert isinstance(mat1, SparseLatticedTensor) or isinstance(mat2, SparseLatticedTensor)
199199
assert (
200200
mat1.ndim == 3
201201
and mat2.ndim == 3
202202
and mat1.shape[0] == mat2.shape[0]
203203
and mat1.shape[2] == mat2.shape[1]
204204
)
205205

206-
mat1_ = to_structured_sparse_tensor(mat1)
207-
mat2_ = to_structured_sparse_tensor(mat2)
206+
mat1_ = to_sparse_latticed_tensor(mat1)
207+
mat2_ = to_sparse_latticed_tensor(mat2)
208208

209209
# TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes
210210
# decompositions. If not, can reshape to common decomposition?
@@ -213,32 +213,32 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
213213

214214
@impl(aten.mm.default)
215215
def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
216-
assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
216+
assert isinstance(mat1, SparseLatticedTensor) or isinstance(mat2, SparseLatticedTensor)
217217
assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0]
218218

219-
mat1_ = to_structured_sparse_tensor(mat1)
220-
mat2_ = to_structured_sparse_tensor(mat2)
219+
mat1_ = to_sparse_latticed_tensor(mat1)
220+
mat2_ = to_sparse_latticed_tensor(mat2)
221221

222222
return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2])
223223

224224

225225
@impl(aten.mean.default)
226-
def mean_default(t: StructuredSparseTensor) -> Tensor:
227-
assert isinstance(t, StructuredSparseTensor)
226+
def mean_default(t: SparseLatticedTensor) -> Tensor:
227+
assert isinstance(t, SparseLatticedTensor)
228228
return aten.sum.default(t.physical) / t.numel()
229229

230230

231231
@impl(aten.sum.default)
232-
def sum_default(t: StructuredSparseTensor) -> Tensor:
233-
assert isinstance(t, StructuredSparseTensor)
232+
def sum_default(t: SparseLatticedTensor) -> Tensor:
233+
assert isinstance(t, SparseLatticedTensor)
234234
return aten.sum.default(t.physical)
235235

236236

237237
@impl(aten.sum.dim_IntList)
238238
def sum_dim_IntList(
239-
t: StructuredSparseTensor, dim: list[int], keepdim: bool = False, dtype=None
239+
t: SparseLatticedTensor, dim: list[int], keepdim: bool = False, dtype=None
240240
) -> Tensor:
241-
assert isinstance(t, StructuredSparseTensor)
241+
assert isinstance(t, SparseLatticedTensor)
242242

243243
if dtype:
244244
raise NotImplementedError()

src/torchjd/sparse/_aten_function_overrides/pointwise.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from torch.ops import aten # type: ignore
22

3-
from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl
3+
from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor, impl
44

55
# pointwise functions applied to one Tensor with `0.0 → 0`
66
_POINTWISE_FUNCTIONS = [
@@ -69,17 +69,17 @@
6969

7070
def _override_pointwise(op):
7171
@impl(op)
72-
def func_(t: StructuredSparseTensor) -> StructuredSparseTensor:
73-
assert isinstance(t, StructuredSparseTensor)
74-
return StructuredSparseTensor(op(t.physical), t.strides)
72+
def func_(t: SparseLatticedTensor) -> SparseLatticedTensor:
73+
assert isinstance(t, SparseLatticedTensor)
74+
return SparseLatticedTensor(op(t.physical), t.strides)
7575

7676
return func_
7777

7878

7979
def _override_inplace_pointwise(op):
8080
@impl(op)
81-
def func_(t: StructuredSparseTensor) -> StructuredSparseTensor:
82-
assert isinstance(t, StructuredSparseTensor)
81+
def func_(t: SparseLatticedTensor) -> SparseLatticedTensor:
82+
assert isinstance(t, SparseLatticedTensor)
8383
op(t.physical)
8484
return t
8585

@@ -92,21 +92,21 @@ def func_(t: StructuredSparseTensor) -> StructuredSparseTensor:
9292

9393

9494
@impl(aten.pow.Tensor_Scalar)
95-
def pow_Tensor_Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor:
96-
assert isinstance(t, StructuredSparseTensor)
95+
def pow_Tensor_Scalar(t: SparseLatticedTensor, exponent: float) -> SparseLatticedTensor:
96+
assert isinstance(t, SparseLatticedTensor)
9797

9898
if exponent <= 0.0:
9999
# Need to densify because we don't have pow(0.0, exponent) = 0.0
100100
return aten.pow.Tensor_Scalar(t.to_dense(), exponent)
101101

102102
new_physical = aten.pow.Tensor_Scalar(t.physical, exponent)
103-
return StructuredSparseTensor(new_physical, t.strides)
103+
return SparseLatticedTensor(new_physical, t.strides)
104104

105105

106106
# Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar.
107107
@impl(aten.pow_.Scalar)
108-
def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor:
109-
assert isinstance(t, StructuredSparseTensor)
108+
def pow__Scalar(t: SparseLatticedTensor, exponent: float) -> SparseLatticedTensor:
109+
assert isinstance(t, SparseLatticedTensor)
110110

111111
if exponent <= 0.0:
112112
# Need to densify because we don't have pow(0.0, exponent) = 0.0
@@ -118,8 +118,8 @@ def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseT
118118

119119

120120
@impl(aten.div.Scalar)
121-
def div_Scalar(t: StructuredSparseTensor, divisor: float) -> StructuredSparseTensor:
122-
assert isinstance(t, StructuredSparseTensor)
121+
def div_Scalar(t: SparseLatticedTensor, divisor: float) -> SparseLatticedTensor:
122+
assert isinstance(t, SparseLatticedTensor)
123123

124124
new_physical = aten.div.Scalar(t.physical, divisor)
125-
return StructuredSparseTensor(new_physical, t.strides)
125+
return SparseLatticedTensor(new_physical, t.strides)

0 commit comments

Comments
 (0)