Skip to content

Commit 5521ad4

Browse files
PierreQuinton and ValerianRey
authored and committed
Add StructuredSparseTensor
1 parent 78a60b3 commit 5521ad4

11 files changed

Lines changed: 1441 additions & 13 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Jupyter notebooks
2+
*.ipynb
3+
14
# uv
25
uv.lock
36

src/torchjd/autogram/_engine.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from torch import Tensor, nn, vmap
55
from torch.autograd.graph import get_gradient_edge
66

7+
from torchjd.sparse import make_sst
8+
79
from ._edge_registry import EdgeRegistry
810
from ._gramian_accumulator import GramianAccumulator
911
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
@@ -173,7 +175,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
173175
)
174176

175177
output_dims = list(range(output.ndim))
176-
jac_output = _make_initial_jac_output(output)
178+
identity = torch.eye(output.ndim, dtype=torch.int64)
179+
strides = torch.concatenate([identity, identity], dim=0)
180+
jac_output = make_sst(torch.ones_like(output), strides)
177181

178182
vmapped_diff = differentiation
179183
for _ in output_dims:
@@ -193,15 +197,3 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
193197
gramian_computer.reset()
194198

195199
return gramian
196-
197-
198-
def _make_initial_jac_output(output: Tensor) -> Tensor:
199-
if output.ndim == 0:
200-
return torch.ones_like(output)
201-
p_index_ranges = [torch.arange(s, device=output.device) for s in output.shape]
202-
p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij")
203-
v_indices_grid = p_indices_grid + p_indices_grid
204-
205-
res = torch.zeros(list(output.shape) * 2, device=output.device, dtype=output.dtype)
206-
res[v_indices_grid] = 1.0
207-
return res

src/torchjd/sparse/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Need to import this to execute the code inside and thus to override the functions
2+
from . import _aten_function_overrides
3+
from ._structured_sparse_tensor import StructuredSparseTensor, make_sst
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import backward, einsum, pointwise, shape
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from torch import Tensor
2+
from torch.ops import aten # type: ignore
3+
4+
from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl
5+
6+
7+
@impl(aten.threshold_backward.default)
def threshold_backward_default(
    grad_output: StructuredSparseTensor, self: Tensor, threshold
) -> StructuredSparseTensor:
    """
    Apply ``aten.threshold_backward`` to the physical tensor of ``grad_output``.

    The threshold backward op is element-wise, so the sparsity structure (the strides) of
    ``grad_output`` is preserved in the result.
    """
    # Consistency with hardtanh_backward_default / hardswish_backward_default: a structured-sparse
    # `self` is not supported, since the physical op would mix mismatched layouts.
    if isinstance(self, StructuredSparseTensor):
        raise NotImplementedError()

    new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold)
    return StructuredSparseTensor(new_physical, grad_output.strides)
14+
15+
16+
@impl(aten.hardtanh_backward.default)
def hardtanh_backward_default(
    grad_output: StructuredSparseTensor,
    self: Tensor,
    min_val: Tensor | int | float,
    max_val: Tensor | int | float,
) -> StructuredSparseTensor:
    """
    Apply ``aten.hardtanh_backward`` to the physical tensor of ``grad_output``.

    The op is element-wise, so the sparsity structure of ``grad_output`` carries over unchanged.
    A structured-sparse ``self`` is not handled.
    """
    if isinstance(self, StructuredSparseTensor):
        raise NotImplementedError()

    physical_grad = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val)
    return StructuredSparseTensor(physical_grad, grad_output.strides)
28+
29+
30+
@impl(aten.hardswish_backward.default)
def hardswish_backward_default(
    grad_output: StructuredSparseTensor, self: Tensor
) -> StructuredSparseTensor:
    """
    Apply ``aten.hardswish_backward`` to the physical tensor of ``grad_output``.

    The op is element-wise, so the sparsity structure of ``grad_output`` is preserved.
    A structured-sparse ``self`` is not handled.
    """
    # Return annotation added for consistency with the other backward overrides in this module.
    if isinstance(self, StructuredSparseTensor):
        raise NotImplementedError()

    new_physical = aten.hardswish_backward.default(grad_output.physical, self)
    return StructuredSparseTensor(new_physical, grad_output.strides)
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
import torch
2+
from torch import Tensor, tensor
3+
from torch.ops import aten # type: ignore
4+
5+
from torchjd.sparse._structured_sparse_tensor import (
6+
StructuredSparseTensor,
7+
impl,
8+
to_most_efficient_tensor,
9+
to_structured_sparse_tensor,
10+
)
11+
12+
13+
def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) -> Tensor:
    """
    Einsum over StructuredSparseTensors, expressed with integer indices.

    Each element of ``args`` is a pair ``(tensor, indices)`` where ``indices`` assigns one einsum
    index to each *virtual* dimension of the tensor; ``output`` lists the indices of the virtual
    dimensions of the result. The contraction itself is carried out on the physical tensors.

    An index of the physical einsum is uniquely characterized by a pair (virtual einsum index,
    position within the physical decomposition of that virtual dimension). A virtual dimension
    whose v_to_ps is a non-trivial list [p_1, ..., p_k] therefore contributes k sub-indices
    (i, 0), ..., (i, k - 1). Sub-indices that land on the same physical dimension of some tensor
    must be contracted together, so they are clustered with a union-find structure, and each
    cluster is then mapped to one index of the physical einsum.

    NOTE(review): two virtual dimensions carrying the same einsum index are assumed to decompose
    into physical dimensions of matching sizes; for now the calling functions are responsible for
    ensuring this — TODO confirm / factor out later.
    """
    # TODO: Handle ellipsis

    # Union-find parent map over sub-indices (einsum index, physical sub-dimension).
    index_parents = dict[tuple[int, int], tuple[int, int]]()

    def get_representative(index: tuple[int, int]) -> tuple[int, int]:
        if index not in index_parents:
            # If an index is not yet in a cluster, put it in its own.
            index_parents[index] = index
        current = index_parents[index]
        if current != index:
            # Compress path to representative
            index_parents[index] = get_representative(current)
        return index_parents[index]

    def group_indices(indices: list[tuple[int, int]]) -> None:
        # Merge all the given sub-indices into the cluster of the first one.
        first_representative = get_representative(indices[0])
        for i in indices[1:]:
            curr_representative = get_representative(i)
            index_parents[curr_representative] = first_representative

    new_indices_pair = list[list[tuple[int, int]]]()
    physicals = list[Tensor]()
    indices_to_n_pdims = dict[int, int]()
    for t, indices in args:
        assert isinstance(t, StructuredSparseTensor)
        physicals.append(t.physical)
        for pdims, index in zip(t.v_to_ps, indices):
            if index in indices_to_n_pdims:
                if indices_to_n_pdims[index] != len(pdims):
                    raise NotImplementedError(
                        "einsum currently does not support having a different number of physical "
                        "dimensions corresponding to matching virtual dimensions of different "
                        f"tensors. Found {[(t.debug_info(), indices) for t, indices in args]}, "
                        f"output_indices={output}."
                    )
            else:
                indices_to_n_pdims[index] = len(pdims)
        # Invert v_to_ps: for each physical dimension p of t, collect the (virtual dim, sub-index)
        # pairs mapping to p. (This replaces the former `p_to_vs = ...` placeholder, which made
        # the loop below crash by iterating over Ellipsis.)
        p_to_vs: list[list[tuple[int, int]]] = [[] for _ in range(t.physical.ndim)]
        for v, pdims in enumerate(t.v_to_ps):
            for sub_i, p in enumerate(pdims):
                p_to_vs[p].append((v, sub_i))
        for indices_ in p_to_vs:
            # elements in indices[indices_] map to the same dimension, they should be clustered
            # together
            group_indices([(indices[i], sub_i) for i, sub_i in indices_])
        # record the physical dimensions, index[v] for v in vs will end-up mapping to the same
        # final dimension as they were just clustered, so we can take the first, which exists as
        # t is a valid SST.
        new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs])

    current = 0
    pair_to_int = dict[tuple[int, int], int]()

    def unique_int(pair: tuple[int, int]) -> int:
        # Map each distinct representative pair to a fresh consecutive integer, in visit order.
        nonlocal current
        if pair in pair_to_int:
            return pair_to_int[pair]
        pair_to_int[pair] = current
        current += 1
        return pair_to_int[pair]

    new_indices = [
        [unique_int(get_representative(i)) for i in indices] for indices in new_indices_pair
    ]
    # Map output indices (a virtual output dimension may split into several physical ones).
    new_output = list[int]()
    v_to_ps = list[list[int]]()
    for i in output:
        current_v_to_ps = []
        for j in range(indices_to_n_pdims[i]):
            k = unique_int(get_representative((i, j)))
            if k in new_output:
                current_v_to_ps.append(new_output.index(k))
            else:
                current_v_to_ps.append(len(new_output))
                new_output.append(k)
        v_to_ps.append(current_v_to_ps)

    physical = torch.einsum(*[x for y in zip(physicals, new_indices) for x in y], new_output)
    # Need to use the safe constructor, otherwise the dimensions may not be maximally grouped.
    # Maybe there is a way to fix that though.
    return to_most_efficient_tensor(physical, v_to_ps)
127+
128+
129+
def prepare_for_elementwise_op(
    t1: Tensor | int | float, t2: Tensor | int | float
) -> tuple[StructuredSparseTensor, StructuredSparseTensor]:
    """
    Turn the two operands of an element-wise op into SSTs of a common (broadcast) shape.

    At least one operand must already be a StructuredSparseTensor; the other may be an SST, a
    Tensor, an int or a float.
    """
    assert isinstance(t1, StructuredSparseTensor) or isinstance(t2, StructuredSparseTensor)

    # Promote python scalars to tensors, placed on the device of the other operand.
    a = tensor(t1, device=t2.device) if isinstance(t1, (int, float)) else t1
    b = tensor(t2, device=t1.device) if isinstance(t2, (int, float)) else t2

    # Broadcast to a common shape, then make sure both sides are SSTs.
    a, b = aten.broadcast_tensors.default([a, b])
    return to_structured_sparse_tensor(a), to_structured_sparse_tensor(b)
154+
155+
156+
@impl(aten.mul.Tensor)
def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
    """Element-wise (broadcasting) multiplication, implemented as an einsum over the two SSTs."""
    lhs, rhs = prepare_for_elementwise_op(t1, t2)
    dims = list(range(lhs.ndim))
    return einsum((lhs, dims), (rhs, dims), output=dims)
162+
163+
164+
@impl(aten.div.Tensor)
def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
    """Element-wise division, expressed as multiplication by the reciprocal of the divisor."""
    numerator, denominator = prepare_for_elementwise_op(t1, t2)
    # Taking the reciprocal of the physical tensor keeps the sparsity structure unchanged.
    reciprocal = StructuredSparseTensor(1.0 / denominator.physical, denominator.strides)
    dims = list(range(numerator.ndim))
    return einsum((numerator, dims), (reciprocal, dims), output=dims)
170+
171+
172+
@impl(aten.mul.Scalar)
def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor:
    """Multiply an SST by a scalar: scaling the physical tensor leaves the structure intact."""
    # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check
    # that
    assert isinstance(t, StructuredSparseTensor)

    scaled = aten.mul.Scalar(t.physical, scalar)
    return StructuredSparseTensor(scaled, t.strides)
180+
181+
182+
@impl(aten.add.Tensor)
def add_Tensor(
    t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0
) -> StructuredSparseTensor:
    """
    Element-wise ``t1 + alpha * t2``, supported only when both operands end up with the same
    sparsity structure (identical strides) after broadcasting.
    """
    lhs, rhs = prepare_for_elementwise_op(t1, t2)

    if not torch.equal(lhs.strides, rhs.strides):
        # Different structures would require densifying or merging layouts.
        raise NotImplementedError()

    return StructuredSparseTensor(lhs.physical + rhs.physical * alpha, lhs.strides)
193+
194+
195+
@impl(aten.bmm.default)
def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Batched matrix multiplication where at least one operand is a StructuredSparseTensor."""
    assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
    assert mat1.ndim == 3 and mat2.ndim == 3
    assert mat1.shape[0] == mat2.shape[0] and mat1.shape[2] == mat2.shape[1]

    lhs = to_structured_sparse_tensor(mat1)
    rhs = to_structured_sparse_tensor(mat2)

    # TODO: Verify that the dimension `0` of the two operands have the same physical dimension
    # sizes decompositions. If not, can reshape to common decomposition?
    # Index 0 is the batch dimension, index 2 is the contracted dimension.
    return einsum((lhs, [0, 1, 2]), (rhs, [0, 2, 3]), output=[0, 1, 3])
211+
212+
213+
@impl(aten.mm.default)
def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Matrix multiplication where at least one operand is a StructuredSparseTensor."""
    assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor)
    assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0]

    lhs = to_structured_sparse_tensor(mat1)
    rhs = to_structured_sparse_tensor(mat2)

    # Index 1 is the contracted dimension.
    return einsum((lhs, [0, 1]), (rhs, [1, 2]), output=[0, 2])
222+
223+
224+
@impl(aten.mean.default)
def mean_default(t: StructuredSparseTensor) -> Tensor:
    """Mean over all (virtual) elements; structural zeros contribute nothing to the sum."""
    assert isinstance(t, StructuredSparseTensor)

    total = aten.sum.default(t.physical)
    # Divide by the virtual element count, not the physical one.
    return total / t.numel()
228+
229+
230+
@impl(aten.sum.default)
def sum_default(t: StructuredSparseTensor) -> Tensor:
    """Sum of all elements: equals the sum of the physical tensor, the rest being zeros."""
    assert isinstance(t, StructuredSparseTensor)

    return aten.sum.default(t.physical)
234+
235+
236+
@impl(aten.sum.dim_IntList)
def sum_dim_IntList(
    t: StructuredSparseTensor, dim: list[int], keepdim: bool = False, dtype=None
) -> Tensor:
    """
    Sum a StructuredSparseTensor over the given (virtual) dimensions.

    The reduction is expressed as an einsum whose output omits the reduced dimensions. When
    ``keepdim`` is true, the reduced dimensions are re-inserted with size 1. Changing ``dtype``
    is not supported.
    """
    assert isinstance(t, StructuredSparseTensor)

    if dtype:
        raise NotImplementedError()

    # Normalize negative dims and sort them: the original code unsqueezed the dims in the order
    # given, which fails (out-of-range unsqueeze) for unsorted lists, and the output-index filter
    # below would miss negative dims entirely.
    dims_to_reduce = sorted(d % t.ndim for d in dim)

    all_dims = list(range(t.ndim))
    result = einsum((t, all_dims), output=[d for d in all_dims if d not in dims_to_reduce])

    if keepdim:
        # Re-insert reduced dimensions with size 1, in increasing position order.
        for d in dims_to_reduce:
            result = result.unsqueeze(d)

    return result

0 commit comments

Comments
 (0)