Skip to content

Commit 63549ca

Browse files
committed
Rename stride to basis
1 parent 131fbb4 commit 63549ca

8 files changed

Lines changed: 143 additions & 151 deletions

File tree

src/torchjd/autogram/_engine.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -176,8 +176,8 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
176176

177177
output_dims = list(range(output.ndim))
178178
identity = torch.eye(output.ndim, dtype=torch.int64)
179-
strides = torch.concatenate([identity, identity], dim=0)
180-
jac_output = make_sst(torch.ones_like(output), strides)
179+
basis = torch.concatenate([identity, identity], dim=0)
180+
jac_output = make_sst(torch.ones_like(output), basis)
181181

182182
vmapped_diff = differentiation
183183
for _ in output_dims:

src/torchjd/sparse/_aten_function_overrides/backward.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ def threshold_backward_default(
1010
) -> SparseLatticedTensor:
1111
new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold)
1212

13-
return SparseLatticedTensor(new_physical, grad_output.strides)
13+
return SparseLatticedTensor(new_physical, grad_output.basis)
1414

1515

1616
@impl(aten.hardtanh_backward.default)
@@ -24,7 +24,7 @@ def hardtanh_backward_default(
2424
raise NotImplementedError()
2525

2626
new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val)
27-
return SparseLatticedTensor(new_physical, grad_output.strides)
27+
return SparseLatticedTensor(new_physical, grad_output.basis)
2828

2929

3030
@impl(aten.hardswish_backward.default)
@@ -33,4 +33,4 @@ def hardswish_backward_default(grad_output: SparseLatticedTensor, self: Tensor):
3333
raise NotImplementedError()
3434

3535
new_physical = aten.hardswish_backward.default(grad_output.physical, self)
36-
return SparseLatticedTensor(new_physical, grad_output.strides)
36+
return SparseLatticedTensor(new_physical, grad_output.basis)

src/torchjd/sparse/_aten_function_overrides/einsum.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,7 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
165165
@impl(aten.div.Tensor)
166166
def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
167167
t1_, t2_ = prepare_for_elementwise_op(t1, t2)
168-
t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.strides)
168+
t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.basis)
169169
all_dims = list(range(t1_.ndim))
170170
return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims)
171171

@@ -177,7 +177,7 @@ def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor:
177177

178178
assert isinstance(t, SparseLatticedTensor)
179179
new_physical = aten.mul.Scalar(t.physical, scalar)
180-
return SparseLatticedTensor(new_physical, t.strides)
180+
return SparseLatticedTensor(new_physical, t.basis)
181181

182182

183183
@impl(aten.add.Tensor)
@@ -186,9 +186,9 @@ def add_Tensor(
186186
) -> SparseLatticedTensor:
187187
t1_, t2_ = prepare_for_elementwise_op(t1, t2)
188188

189-
if torch.equal(t1_.strides, t2_.strides):
189+
if torch.equal(t1_.basis, t2_.basis):
190190
new_physical = t1_.physical + t2_.physical * alpha
191-
return SparseLatticedTensor(new_physical, t1_.strides)
191+
return SparseLatticedTensor(new_physical, t1_.basis)
192192
else:
193193
raise NotImplementedError()
194194

src/torchjd/sparse/_aten_function_overrides/pointwise.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ def _override_pointwise(op):
7171
@impl(op)
7272
def func_(t: SparseLatticedTensor) -> SparseLatticedTensor:
7373
assert isinstance(t, SparseLatticedTensor)
74-
return SparseLatticedTensor(op(t.physical), t.strides)
74+
return SparseLatticedTensor(op(t.physical), t.basis)
7575

7676
return func_
7777

@@ -100,7 +100,7 @@ def pow_Tensor_Scalar(t: SparseLatticedTensor, exponent: float) -> SparseLattice
100100
return aten.pow.Tensor_Scalar(t.to_dense(), exponent)
101101

102102
new_physical = aten.pow.Tensor_Scalar(t.physical, exponent)
103-
return SparseLatticedTensor(new_physical, t.strides)
103+
return SparseLatticedTensor(new_physical, t.basis)
104104

105105

106106
# Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar.
@@ -122,4 +122,4 @@ def div_Scalar(t: SparseLatticedTensor, divisor: float) -> SparseLatticedTensor:
122122
assert isinstance(t, SparseLatticedTensor)
123123

124124
new_physical = aten.div.Scalar(t.physical, divisor)
125-
return SparseLatticedTensor(new_physical, t.strides)
125+
return SparseLatticedTensor(new_physical, t.basis)

src/torchjd/sparse/_aten_function_overrides/shape.py

Lines changed: 42 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -28,8 +28,8 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor:
2828
c.T = [prod(t.shape[1:]), prod(t.shape[2:]), ..., t.shape[-1], 1]
2929
* c' is the same thing but after the reshape, i.e.
3030
c'.T = [prod(shape[1:]), prod(shape[2:]), ..., shape[-1], 1]
31-
* S is the original matrix of strides (t.strides)
32-
* S' is the matrix of strides after reshaping.
31+
* S is the original basis matrix (t.basis)
32+
* S' is the basis matrix after reshaping.
3333
3434
For u, v in Z^m and c in Z, say that u ≡ v (mod c) if u_i ≡ v_i (mod c) for all i.
3535
Note that c'.T S' ≡ S'[-1] (mod shape[-1])
@@ -46,12 +46,12 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor:
4646
if prod(shape) != t.numel():
4747
raise ValueError(f"shape '{shape}' is invalid for input of size {t.numel()}")
4848

49-
S = t.strides
49+
S = t.basis
5050
vshape = list(t.shape)
5151
c = _reverse_cumulative_product(vshape)
5252
c_prime = _reverse_cumulative_product(shape)
53-
new_strides = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1)
54-
return to_most_efficient_tensor(t.physical, new_strides)
53+
new_basis = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1)
54+
return to_most_efficient_tensor(t.physical, new_basis)
5555

5656

5757
def _reverse_cumulative_product(values: list[int]) -> Tensor:
@@ -70,7 +70,7 @@ def infer_shape(shape: list[int], numel: int) -> list[int]:
7070

7171

7272
def unsquash_pdim(
73-
physical: Tensor, strides: Tensor, pdim: int, new_pdim_shape: list[int]
73+
physical: Tensor, basis: Tensor, pdim: int, new_pdim_shape: list[int]
7474
) -> tuple[Tensor, Tensor]:
7575
"""
7676
EXAMPLE:
@@ -80,7 +80,7 @@ def unsquash_pdim(
8080
[7, 8, 9, 10, 11, 12],
8181
[13, 14, 15, 16, 17, 18],
8282
]
83-
strides = [
83+
basis = [
8484
[1, 1],
8585
[0, 2],
8686
]
@@ -99,7 +99,7 @@ def unsquash_pdim(
9999
[16, 17, 18],
100100
]]
101101
102-
new_strides = [
102+
new_basis = [
103103
[1, 3, 1],
104104
[0, 6, 2]
105105
"""
@@ -110,18 +110,18 @@ def unsquash_pdim(
110110
new_shape = old_shape[:pdim] + new_pdim_shape + old_shape[pdim + 1 :]
111111
new_physical = physical.reshape(new_shape)
112112

113-
stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))])
113+
multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))])
114114

115-
new_strides = torch.concat(
115+
new_basis = torch.concat(
116116
[
117-
strides[:, :pdim],
118-
torch.outer(strides[:, pdim], stride_multipliers),
119-
strides[:, pdim + 1 :],
117+
basis[:, :pdim],
118+
torch.outer(basis[:, pdim], multipliers),
119+
basis[:, pdim + 1 :],
120120
],
121121
dim=1,
122122
)
123123

124-
return new_physical, new_strides
124+
return new_physical, new_basis
125125

126126

127127
@impl(aten._unsafe_view.default)
@@ -139,10 +139,10 @@ def unsqueeze_default(t: SparseLatticedTensor, dim: int) -> SparseLatticedTensor
139139
if dim < 0:
140140
dim = t.ndim + dim + 1
141141

142-
new_strides = torch.concatenate(
143-
[t.strides[:dim], torch.zeros(1, t.strides.shape[1], dtype=torch.int64), t.strides[dim:]]
142+
new_basis = torch.concatenate(
143+
[t.basis[:dim], torch.zeros(1, t.basis.shape[1], dtype=torch.int64), t.basis[dim:]]
144144
)
145-
return SparseLatticedTensor(t.physical, new_strides)
145+
return SparseLatticedTensor(t.physical, new_basis)
146146

147147

148148
@impl(aten.squeeze.dims)
@@ -157,14 +157,14 @@ def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tenso
157157
excluded = set(dims)
158158

159159
is_row_kept = [i not in excluded for i in range(t.ndim)]
160-
new_strides = t.strides[is_row_kept]
161-
return to_most_efficient_tensor(t.physical, new_strides)
160+
new_basis = t.basis[is_row_kept]
161+
return to_most_efficient_tensor(t.physical, new_basis)
162162

163163

164164
@impl(aten.permute.default)
165165
def permute_default(t: SparseLatticedTensor, dims: list[int]) -> SparseLatticedTensor:
166-
new_strides = t.strides[torch.tensor(dims)]
167-
return SparseLatticedTensor(t.physical, new_strides)
166+
new_basis = t.basis[torch.tensor(dims)]
167+
return SparseLatticedTensor(t.physical, new_basis)
168168

169169

170170
@impl(aten.cat.default)
@@ -175,11 +175,11 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
175175

176176
tensors_ = [cast(SparseLatticedTensor, t) for t in tensors]
177177
ref_tensor = tensors_[0]
178-
ref_strides = ref_tensor.strides
179-
if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]):
178+
ref_basis = ref_tensor.basis
179+
if any(not torch.equal(t.basis, ref_basis) for t in tensors_[1:]):
180180
raise NotImplementedError(
181181
"Override for aten.cat.default does not support SSTs that do not all have the same "
182-
f"strides. Found the following tensors:\n{[t.debug_info() for t in tensors_]} and the "
182+
f"basis. Found the following tensors:\n{[t.debug_info() for t in tensors_]} and the "
183183
f"following dim: {dim}."
184184
)
185185

@@ -189,8 +189,8 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
189189

190190
ref_virtual_dim_size = ref_tensor.shape[dim]
191191
indices = torch.argwhere(
192-
torch.eq(ref_strides[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
193-
& torch.eq(ref_strides.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
192+
torch.eq(ref_basis[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
193+
& torch.eq(ref_basis.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
194194
)
195195
assert len(indices) <= 1
196196

@@ -200,18 +200,18 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
200200

201201
pdim = ref_tensor.physical.ndim
202202
physicals = [t.physical.unsqueeze(-1) for t in tensors_]
203-
new_stride_column = torch.zeros(ref_tensor.ndim, 1, dtype=torch.int64)
204-
new_stride_column[dim, 0] = ref_virtual_dim_size
205-
new_strides = torch.concatenate([ref_tensor.strides, new_stride_column], dim=1)
203+
new_basis_vector = torch.zeros(ref_tensor.ndim, 1, dtype=torch.int64)
204+
new_basis_vector[dim, 0] = ref_virtual_dim_size
205+
new_basis = torch.concatenate([ref_tensor.basis, new_basis_vector], dim=1)
206206
else:
207207
# Such a physical dimension already exists. Note that an alternative implementation would be
208208
# to simply always add the physical dimension, and squash it if it ends up being not needed.
209209
physicals = [t.physical for t in tensors_]
210210
pdim = cast(int, indices[0, 0].item())
211-
new_strides = ref_tensor.strides
211+
new_basis = ref_tensor.basis
212212

213213
new_physical = aten.cat.default(physicals, dim=pdim)
214-
return SparseLatticedTensor(new_physical, new_strides)
214+
return SparseLatticedTensor(new_physical, new_basis)
215215

216216

217217
@impl(aten.expand.default)
@@ -227,26 +227,26 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedT
227227

228228
# Try to expand each dimension to its new size
229229
new_physical = t.physical
230-
new_strides = t.strides
231-
for d, (vstride, orig_size, new_size) in enumerate(zip(t.strides, t.shape, sizes, strict=True)):
232-
if vstride.sum() > 0 and orig_size != new_size and new_size != -1:
230+
new_basis = t.basis
231+
for d, (v, orig_size, new_size) in enumerate(zip(t.basis, t.shape, sizes, strict=True)):
232+
if v.sum() > 0 and orig_size != new_size and new_size != -1:
233233
raise ValueError(
234234
f"Cannot expand dim {d} of size != 1. Found size {orig_size} and target size "
235235
f"{new_size}."
236236
)
237237

238-
if vstride.sum() == 0 and new_size != 1 and new_size != -1:
238+
if v.sum() == 0 and new_size != 1 and new_size != -1:
239239
# Add a dimension of size new_size at the end of the physical tensor.
240240
new_physical_shape = list(new_physical.shape) + [new_size]
241241
new_physical = new_physical.unsqueeze(-1).expand(new_physical_shape)
242242

243-
# Make this new physical dimension have a stride of 1 at virtual dimension d and 0 at
244-
# every other virtual dimension
245-
new_stride_column = torch.zeros(t.ndim, 1, dtype=torch.int64)
246-
new_stride_column[d, 0] = 1
247-
new_strides = torch.cat([new_strides, new_stride_column], dim=1)
243+
# Make the basis vector of this new physical dimension be 1 at virtual dimension d and 0
244+
# at every other virtual dimension
245+
new_basis_vector = torch.zeros(t.ndim, 1, dtype=torch.int64)
246+
new_basis_vector[d, 0] = 1
247+
new_basis = torch.cat([new_basis, new_basis_vector], dim=1)
248248

249-
return SparseLatticedTensor(new_physical, new_strides)
249+
return SparseLatticedTensor(new_physical, new_basis)
250250

251251

252252
@impl(aten.broadcast_tensors.default)
@@ -279,7 +279,7 @@ def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]:
279279
@impl(aten.transpose.int)
280280
def transpose_int(t: SparseLatticedTensor, dim0: int, dim1: int) -> SparseLatticedTensor:
281281
assert isinstance(t, SparseLatticedTensor)
282-
return SparseLatticedTensor(t.physical, _swap_rows(t.strides, dim0, dim1))
282+
return SparseLatticedTensor(t.physical, _swap_rows(t.basis, dim0, dim1))
283283

284284

285285
def _swap_rows(matrix: Tensor, c0: int, c1: int) -> Tensor:

src/torchjd/sparse/_coalesce.py

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2,18 +2,17 @@
22
from torch import Tensor
33

44

5-
def fix_zero_stride_columns(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]:
5+
def fix_zero_basis_vectors(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]:
66
"""
7-
Remove columns of strides that are all 0 and sum the corresponding elements in the physical
8-
tensor.
7+
Remove basis vectors that are all 0 and sum the corresponding elements in the physical tensor.
98
"""
109

11-
are_columns_zero = (strides == 0).all(dim=0)
10+
are_vectors_zero = (basis == 0).all(dim=0)
1211

13-
if not are_columns_zero.any():
14-
return physical, strides
12+
if not are_vectors_zero.any():
13+
return physical, basis
1514

16-
zero_column_indices = torch.arange(len(are_columns_zero))[are_columns_zero].tolist()
15+
zero_column_indices = torch.arange(len(are_vectors_zero))[are_vectors_zero].tolist()
1716
physical = physical.sum(dim=zero_column_indices)
18-
strides = strides[:, ~are_columns_zero]
19-
return physical, strides
17+
basis = basis[:, ~are_vectors_zero]
18+
return physical, basis

0 commit comments

Comments
 (0)