11import torch
22from torch import Tensor
33
4+ # TODO: Implement in C everything in this file.
5+
46
57def solve_int (A : Tensor , B : Tensor , tol = 1e-9 ) -> Tensor | None :
68 """
@@ -22,3 +24,206 @@ def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None:
2224
2325 # TODO: Verify that the round operation cannot fail
2426 return X_rounded .to (torch .int64 )
27+
28+
def extended_gcd(a: int, b: int) -> tuple[int, int, int]:
    """
    Extended Euclidean Algorithm on plain Python ints.

    Returns:
        (g, x, y) such that a*x + b*y == g, where g is gcd(a, b)
        (the sign of g follows the recursion and may be negative
        for negative inputs).
    """
    # Base case: gcd(0, b) = b with Bezout coefficients (0, 1).
    if a == 0:
        return b, 0, 1

    # Recurse on (b mod a, a); |a| strictly shrinks, so this terminates.
    g, coef_rem, coef_a = extended_gcd(b % a, a)

    # Back-substitute b % a = b - (b // a) * a into
    # coef_rem * (b % a) + coef_a * a == g.
    return g, coef_a - (b // a) * coef_rem, coef_rem
41+
42+
def hnf_decomposition(A: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """
    Computes the Hermite Normal Form decomposition using PyTorch.

    Column-operation (lower-triangular) variant: unimodular column
    operations are applied to a copy of A until it reaches HNF, while the
    same operations are accumulated into U and their inverses into V, so
    that A @ U = H and H @ V = A hold throughout.

    Args:
        A: (m x n) torch.Tensor (dtype=torch.long)

    Returns:
        H: (m x n) Canonical Lower Triangular HNF
        U: (n x n) Unimodular transform (A @ U = H)
        V: (n x n) Inverse Unimodular transform (H @ V = A)
    """

    H = A.clone().to(dtype=torch.long)
    m, n = H.shape

    # Accumulators: U records the column ops applied to H, V their inverses.
    U = torch.eye(n, dtype=torch.long)
    V = torch.eye(n, dtype=torch.long)

    row = 0
    col = 0

    while row < m and col < n:
        # --- 1. Pivot Selection ---
        # Find first non-zero entry in current row from col onwards
        pivot_idx = -1

        # We extract the row slice to CPU for faster scalar checks if on GPU
        # or just iterate. For HNF, strictly sequential loop is often easiest.
        for j in range(col, n):
            if H[row, j] != 0:
                pivot_idx = j
                break

        if pivot_idx == -1:
            # Row is entirely zero from col on: advance the row, keep col.
            row += 1
            continue

        # Swap to current column
        if pivot_idx != col:
            # Swap Columns in H and U
            H[:, [col, pivot_idx]] = H[:, [pivot_idx, col]]
            U[:, [col, pivot_idx]] = U[:, [pivot_idx, col]]
            # Swap ROWS in V (a column swap is its own inverse as a row swap)
            V[[col, pivot_idx], :] = V[[pivot_idx, col], :]

        # --- 2. Gaussian Elimination via GCD ---
        # Zero out H[row, j] for every j > col using 2-column unimodular mixes.
        for j in range(col + 1, n):
            if H[row, j] != 0:
                # Extract values as python ints for GCD logic
                a_val = H[row, col].item()
                b_val = H[row, j].item()

                g, x, y = extended_gcd(a_val, b_val)

                # Bezout: a*x + b*y = g
                # c1 = -b // g, c2 = a // g (exact divisions: g divides a and b)
                c1 = -b_val // g
                c2 = a_val // g

                # The 2x2 mix [[x, c1], [y, c2]] has determinant
                # x*c2 - y*c1 = (a*x + b*y) / g = 1, so it is unimodular.

                # --- Update H (Column Ops) ---
                # Important: Clone columns to avoid in-place modification issues during calc
                col_c = H[:, col].clone()
                col_j = H[:, j].clone()

                H[:, col] = col_c * x + col_j * y  # pivot entry becomes g
                H[:, j] = col_c * c1 + col_j * c2  # row entry: -a*b/g + b*a/g = 0

                # --- Update U (Column Ops) ---
                u_c = U[:, col].clone()
                u_j = U[:, j].clone()
                U[:, col] = u_c * x + u_j * y
                U[:, j] = u_c * c1 + u_j * c2

                # --- Update V (Inverse Row Ops) ---
                # Inverse of [[x, c1], [y, c2]] is [[c2, -c1], [-y, x]]
                v_r_c = V[col, :].clone()
                v_r_j = V[j, :].clone()
                V[col, :] = v_r_c * c2 - v_r_j * c1
                V[j, :] = v_r_c * (-y) + v_r_j * x

        # --- 3. Enforce Positive Diagonal ---
        # Negating a column (and the matching row of V) is unimodular.
        if H[row, col] < 0:
            H[:, col] *= -1
            U[:, col] *= -1
            V[col, :] *= -1

        # --- 4. Canonical Reduction (Modulo) ---
        # Ensure 0 <= H[row, k] < H[row, col] for k < col
        pivot_val = H[row, col].clone()
        if pivot_val != 0:
            for j in range(col):
                # floor division
                factor = torch.div(H[row, j], pivot_val, rounding_mode="floor")

                if factor != 0:
                    H[:, j] -= factor * H[:, col]
                    U[:, j] -= factor * U[:, col]
                    # Inverse of the column shear goes onto V's rows.
                    V[col, :] += factor * V[j, :]

        row += 1
        col += 1

    return H, U, V
147+
148+
def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """
    Computes the GCD and the projection factors. i.e.
    S1 = G @ K1
    S2 = G @ K2

    Args:
        S1, S2: torch.Tensors (m x n1), (m x n2)

    Returns:
        G: (m x m) The Matrix GCD (Canonical Base)
        K1: (m x n1) Factors for S1
        K2: (m x n2) Factors for S2
    """
    assert S1.shape[0] == S2.shape[0], "Virtual dimension mismatch"
    m, n1 = S1.shape[0], S1.shape[1]

    # HNF of the horizontally stacked generators [S1 | S2].
    stacked = torch.cat([S1, S2], dim=1)
    H, _, V = hnf_decomposition(stacked)

    # First m columns of H form the canonical basis of the joint span.
    G = H[:, :m]

    # Since H @ V = [S1 | S2], the top m rows of V hold the coefficients
    # expressing each original column in terms of G.
    coeffs = V[:m, :]
    return G, coeffs[:, :n1], coeffs[:, n1:]
185+
186+
def compute_lcm(S1: Tensor, S2: Tensor) -> Tensor:
    """
    Computes the Matrix LCM of the column spans of S1 and S2, i.e. the
    canonical L such that L = S1 @ M1 = S2 @ M2 for some integer M1, M2.

    NOTE(review): an earlier docstring promised returning (L, M1, M2), but
    every return path produces only L; the multiples are not implemented.
    TODO: compute and return M1, M2 to mirror compute_gcd's (G, K1, K2).

    Args:
        S1: (m x n1) torch.Tensor (dtype=torch.long)
        S2: (m x n2) torch.Tensor (dtype=torch.long)

    Returns:
        L: The Matrix LCM (all-zero (m x m) when the kernel of
           [S1 | -S2] is trivial, i.e. only the zero common multiple exists).
    """
    assert S1.shape[0] == S2.shape[0], "Virtual dimension mismatch"
    m = S1.shape[0]
    n1 = S1.shape[1]

    # 1. Kernel Setup: [S1 | -S2]. A kernel vector [u; v] of B satisfies
    #    S1 @ u == S2 @ v, i.e. it encodes a common multiple.
    B = torch.cat([S1, -S2], dim=1)

    # 2. Decompose to find Kernel: B @ U_B = H_B, so zero columns of H_B
    #    mark kernel generators among the columns of U_B.
    H_B, U_B, _ = hnf_decomposition(B)

    # 3. Find Zero Columns in H_B (Kernel basis)
    col_mags = torch.sum(torch.abs(H_B), dim=0)
    zero_indices = torch.nonzero(col_mags == 0, as_tuple=True)[0]

    # Trivial kernel: no non-zero common multiple.
    if len(zero_indices) == 0:
        return torch.zeros((m, m), dtype=torch.long)

    # 4. Extract Kernel Basis
    kernel_basis = U_B[:, zero_indices]

    # 5. Map back to Image Space: top n1 rows of [u; v] are the
    #    S1-coefficients u, and the intersection is spanned by S1 @ u.
    u_parts = kernel_basis[:n1, :]
    L_generators = S1 @ u_parts

    # 6. Canonicalize: the generators might be redundant or non-square;
    #    a final HNF yields the unique canonical basis.
    L, _, _ = hnf_decomposition(L_generators)
    return L[:, :m]