Skip to content

Commit 4f19317

Browse files
committed
Improve GCD for tall stride matrices.
1 parent f00377a commit 4f19317

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

src/torchjd/sparse/_linalg.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,18 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]:
170170
# 2. Decompose
171171
H, U, V = hnf_decomposition(A)
172172

173-
# 3. Extract G (First m columns of H)
174-
G = H[:, :m]
173+
col_magnitudes = torch.sum(torch.abs(H), dim=0)
174+
# Find the last index that is non-zero.
175+
non_zero_indices = torch.nonzero(col_magnitudes, as_tuple=True)[0]
176+
177+
if len(non_zero_indices) == 0:
178+
rank = 0
179+
else:
180+
rank = non_zero_indices.max().item() + 1
181+
182+
# 3. Extract G (Compact Basis)
183+
# We only take the first 'rank' columns.
184+
G = H[:, :rank]
175185

176186
# 4. Extract Factors from V
177187
# S = G @ V_top.

0 commit comments

Comments
 (0)