Skip to content

Commit f00377a

Browse files
committed
Add HNF decomposition, LCM and GCD.
1 parent ba6e65f commit f00377a

1 file changed

Lines changed: 205 additions & 0 deletions

File tree

src/torchjd/sparse/_linalg.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import torch
22
from torch import Tensor
33

4+
# TODO: Implement in C everything in this file.
5+
46

57
def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None:
68
"""
@@ -22,3 +24,206 @@ def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None:
2224

2325
# TODO: Verify that the round operation cannot fail
2426
return X_rounded.to(torch.int64)
27+
28+
29+
def extended_gcd(a: int, b: int) -> tuple[int, int, int]:
    """
    Extended Euclidean Algorithm (Python integers), written iteratively.

    Returns (g, x, y) such that a*x + b*y = g.
    """
    # Forward pass: run Euclid's algorithm, recording the quotient taken at
    # each step so the Bezout coefficients can be rebuilt afterwards.
    quotients = []
    while a != 0:
        quotients.append(b // a)
        a, b = b % a, a

    # Base case of the classic recursion: gcd(0, b) -> (b, 0, 1).
    g, x, y = b, 0, 1

    # Backward pass: undo each division step from innermost to outermost,
    # updating (x, y) exactly as the recursive unwinding would.
    for q in reversed(quotients):
        x, y = y - q * x, x
    return g, x, y
41+
42+
43+
def hnf_decomposition(A: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """
    Computes the (column-style) Hermite Normal Form decomposition using PyTorch.

    Only unimodular *column* operations are applied to a working copy of A.
    Each operation is mirrored as the same column op on U and as the inverse
    row op on V, so the invariants A @ U == H and H @ V == A hold throughout.

    Args:
        A: (m x n) torch.Tensor (dtype=torch.long)

    Returns:
        H: (m x n) Canonical Lower Triangular HNF
        U: (n x n) Unimodular transform (A @ U = H)
        V: (n x n) Inverse Unimodular transform (H @ V = A)
    """

    # Work on a long-typed copy so the caller's A is never mutated.
    H = A.clone().to(dtype=torch.long)
    m, n = H.shape

    # NOTE(review): U and V are created on the default device; for a CUDA
    # input A the mixed-device arithmetic below (e.g. `factor * U[:, col]`)
    # would fail — this presumably expects CPU tensors. Confirm with callers.
    U = torch.eye(n, dtype=torch.long)
    V = torch.eye(n, dtype=torch.long)

    row = 0
    col = 0

    while row < m and col < n:
        # --- 1. Pivot Selection ---
        # Find first non-zero entry in current row from col onwards
        pivot_idx = -1

        # We extract the row slice to CPU for faster scalar checks if on GPU
        # or just iterate. For HNF, strictly sequential loop is often easiest.
        for j in range(col, n):
            if H[row, j] != 0:
                pivot_idx = j
                break

        if pivot_idx == -1:
            # Entire remaining row is zero: no pivot for this row; move on
            # without consuming a column.
            row += 1
            continue

        # Swap to current column
        if pivot_idx != col:
            # Swap Columns in H and U
            H[:, [col, pivot_idx]] = H[:, [pivot_idx, col]]
            U[:, [col, pivot_idx]] = U[:, [pivot_idx, col]]
            # Swap ROWS in V
            V[[col, pivot_idx], :] = V[[pivot_idx, col], :]

        # --- 2. Gaussian Elimination via GCD ---
        # Combine column `col` with each later column j so that, in this row,
        # the pivot entry becomes gcd(a, b) and entry j becomes zero:
        # a*x + b*y = g  and  a*(-b//g) + b*(a//g) = 0.
        for j in range(col + 1, n):
            if H[row, j] != 0:
                # Extract values as python ints for GCD logic
                a_val = H[row, col].item()
                b_val = H[row, j].item()

                g, x, y = extended_gcd(a_val, b_val)

                # Bezout: a*x + b*y = g
                # c1 = -b // g, c2 = a // g
                # (g divides both a_val and b_val, so these are exact.)
                c1 = -b_val // g
                c2 = a_val // g

                # --- Update H (Column Ops) ---
                # Important: Clone columns to avoid in-place modification issues during calc
                col_c = H[:, col].clone()
                col_j = H[:, j].clone()

                H[:, col] = col_c * x + col_j * y
                H[:, j] = col_c * c1 + col_j * c2

                # --- Update U (Column Ops) ---
                u_c = U[:, col].clone()
                u_j = U[:, j].clone()
                U[:, col] = u_c * x + u_j * y
                U[:, j] = u_c * c1 + u_j * c2

                # --- Update V (Inverse Row Ops) ---
                # The 2x2 column transform [[x, c1], [y, c2]] has determinant
                # (a*x + b*y)/g = 1; its inverse [[c2, -c1], [-y, x]] is
                # applied to the corresponding ROWS of V.
                v_r_c = V[col, :].clone()
                v_r_j = V[j, :].clone()
                V[col, :] = v_r_c * c2 - v_r_j * c1
                V[j, :] = v_r_c * (-y) + v_r_j * x

        # --- 3. Enforce Positive Diagonal ---
        # Negating a column of H/U is its own inverse row-negation on V.
        if H[row, col] < 0:
            H[:, col] *= -1
            U[:, col] *= -1
            V[col, :] *= -1

        # --- 4. Canonical Reduction (Modulo) ---
        # Ensure 0 <= H[row, k] < H[row, col] for k < col
        pivot_val = H[row, col].clone()
        if pivot_val != 0:
            for j in range(col):
                # floor division
                factor = torch.div(H[row, j], pivot_val, rounding_mode="floor")

                if factor != 0:
                    # Subtracting factor * pivot-column from column j is
                    # undone on V by adding factor * row j to the pivot row.
                    H[:, j] -= factor * H[:, col]
                    U[:, j] -= factor * U[:, col]
                    V[col, :] += factor * V[j, :]

        row += 1
        col += 1

    return H, U, V
147+
148+
149+
def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """
    Computes the GCD and the projection factors. i.e.
    S1 = G @ K1
    S2 = G @ K2

    Args:
        S1, S2: torch.Tensors (m x n1), (m x n2)

    Returns:
        G: (m x m) The Matrix GCD (Canonical Base)
        K1: (m x n1) Factors for S1
        K2: (m x n2) Factors for S2
    """
    assert S1.shape[0] == S2.shape[0], "Virtual dimension mismatch"

    m = S1.shape[0]
    n1 = S1.shape[1]

    # HNF of the horizontally stacked generators [S1 | S2].
    hermite, _, inverse = hnf_decomposition(torch.cat([S1, S2], dim=1))

    # The first m columns of the HNF form the canonical base (matrix GCD).
    gcd_base = hermite[:, :m]

    # [S1 | S2] = H @ V, so the top m rows of V carry the coefficients K
    # directly; split them back into the S1 and S2 parts.
    coefficients = inverse[:m, :]
    return gcd_base, coefficients[:, :n1], coefficients[:, n1:]
185+
186+
187+
def compute_lcm(S1: Tensor, S2: Tensor) -> Tensor:
    """
    Computes the Matrix LCM L of S1 and S2: a canonical basis of the
    intersection of their integer column lattices, so that
    L = S1 @ M1 = S2 @ M2 for some integer matrices M1, M2.

    NOTE(review): the multiples M1 and M2 are NOT computed or returned,
    despite what the original docstring claimed — only L is returned.

    Args:
        S1, S2: torch.Tensors (m x n1), (m x n2)

    Returns:
        L: (m x m) The Matrix LCM. An all-zero (m x m) matrix is returned
           when the kernel of [S1 | -S2] is trivial.
    """
    m = S1.shape[0]
    n1 = S1.shape[1]

    # 1. Kernel Setup: [S1 | -S2]
    # A kernel vector [u; v] of this stacked matrix satisfies S1 @ u == S2 @ v,
    # i.e. it encodes a common point of the two column lattices.
    B = torch.cat([S1, -S2], dim=1)

    # 2. Decompose to find Kernel
    H_B, U_B, _ = hnf_decomposition(B)

    # 3. Find Zero Columns in H_B (Kernel basis)
    # Sum abs values down columns
    col_mags = torch.sum(torch.abs(H_B), dim=0)
    zero_indices = torch.nonzero(col_mags == 0, as_tuple=True)[0]

    if len(zero_indices) == 0:
        # Trivial intersection: return the all-zero LCM.
        return torch.zeros((m, m), dtype=torch.long)

    # 4. Extract Kernel Basis
    # U_B columns corresponding to H_B zeros are the kernel generators
    # (B @ U_B == H_B, so those columns of U_B are mapped to zero by B).
    kernel_basis = U_B[:, zero_indices]

    # 5. Map back to Image Space
    # The kernel vector is [u; v]. We need u (top n1 rows).
    # Intersection = S1 @ u
    u_parts = kernel_basis[:n1, :]
    L_generators = S1 @ u_parts

    # 6. Canonicalize L
    # The generators might be redundant or non-square.
    # Run HNF one last time to get the unique square LCM matrix.
    L, _, _ = hnf_decomposition(L_generators)

    # NOTE(review): if fewer than m kernel generators exist, L has fewer than
    # m columns and this slice yields a matrix narrower than the documented
    # (m x m) shape — confirm whether callers can hit that case.
    return L[:, :m]

0 commit comments

Comments
 (0)