Skip to content

Commit 4dbce6d

Browse files
committed
Revamp compute_gcd
1 parent 4f19317 commit 4dbce6d

1 file changed

Lines changed: 18 additions & 13 deletions

File tree

src/torchjd/sparse/_linalg.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -161,32 +161,37 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]:
161161
K2: (m x n2) Factors for S2
162162
"""
163163
assert S1.shape[0] == S2.shape[0], "Virtual dimension mismatch"
164-
m = S1.shape[0]
165-
n1 = S1.shape[1]
164+
m, n1 = S1.shape
166165

167-
# 1. Stack: [S1 | S2]
168166
A = torch.cat([S1, S2], dim=1)
169-
170-
# 2. Decompose
171167
H, U, V = hnf_decomposition(A)
172168

169+
# H = [S1 | S2] @ U
170+
# [S1 | S2] = H @ V
171+
#
172+
# S1 = H @ V[:, :m1]
173+
# S2 = H @ V[:, m1:]
174+
#
175+
# K1 = V[:, :m1]
176+
# K2 = V[:, m1:]
177+
# G = H
178+
#
179+
# S1 = G @ K1
180+
# S2 = G @ K2
181+
#
182+
# SST(p1, S1) = SST(SST(p1, K1), G)
183+
# SST(p2, S2) = SST(SST(p2, K2), G)
184+
173185
col_magnitudes = torch.sum(torch.abs(H), dim=0)
174-
# Find the last index that is non-zero.
175186
non_zero_indices = torch.nonzero(col_magnitudes, as_tuple=True)[0]
176187

177188
if len(non_zero_indices) == 0:
178189
rank = 0
179190
else:
180191
rank = non_zero_indices.max().item() + 1
181192

182-
# 3. Extract G (Compact Basis)
183-
# We only take the first 'rank' columns.
184193
G = H[:, :rank]
185-
186-
# 4. Extract Factors from V
187-
# S = G @ V_top.
188-
# V tracks the inverse transforms, so it contains the coefficients K directly.
189-
V_active = V[:m, :] # Top m rows
194+
V_active = V[:rank, :]
190195

191196
K1 = V_active[:, :n1]
192197
K2 = V_active[:, n1:]

0 commit comments

Comments
 (0)