|
| 1 | +from scipy.sparse.linalg import LinearOperator, aslinearoperator |
| 2 | +from scipy.sparse import diags, issparse |
| 3 | +import numpy as np |
| 4 | + |
| 5 | + |
| 6 | +def is_sparse_or_lin_op(a): |
| 7 | + return issparse(a) or isinstance(a, LinearOperator) |
| 8 | + |
| 9 | + |
| 10 | +def safe_hstack(tup): |
| 11 | + |
| 12 | + if any(is_sparse_or_lin_op(t) for t in tup): |
| 13 | + return HStacked(tup) |
| 14 | + else: |
| 15 | + return np.hstack(tup) |
| 16 | + |
| 17 | + |
| 18 | +class HStacked(LinearOperator): |
| 19 | + """ |
| 20 | + Represents np.hstack |
| 21 | + """ |
| 22 | + def __init__(self, tup): |
| 23 | + |
| 24 | + n_rows = tup[0].shape[0] |
| 25 | + self.tup_n_cols = [] |
| 26 | + self.tup = [] |
| 27 | + for t in tup: |
| 28 | + assert t.shape[0] == n_rows |
| 29 | + if t.ndim == 0: |
| 30 | + self.tup.append(t.reshape(-1, 1)) |
| 31 | + else: |
| 32 | + self.tup.append(t) |
| 33 | + |
| 34 | + self.tup_n_cols.append(t.shape[1]) |
| 35 | + |
| 36 | + shape = (n_rows, sum(self.tup_n_cols)) |
| 37 | + |
| 38 | + dtype = tup[0].dtype |
| 39 | + super().__init__(dtype=dtype, shape=shape) |
| 40 | + |
| 41 | + def _matvec(self, x): |
| 42 | + out = [] |
| 43 | + left_idx = 0 |
| 44 | + right_idx = 0 |
| 45 | + for idx, n_cols in enumerate(self.tup_n_cols): |
| 46 | + right_idx += n_cols |
| 47 | + out.append(self.tup[idx] @ x[left_idx:right_idx]) |
| 48 | + left_idx += n_cols |
| 49 | + |
| 50 | + return sum(o for o in out) |
| 51 | + |
| 52 | + def _rmatvec(self, x): |
| 53 | + return np.concatenate([mat.T @ x for mat in self.tup]) |
| 54 | + |
| 55 | + |
| 56 | +class OnesOuterVec(LinearOperator): |
| 57 | + """ |
| 58 | + Represents the outer product 1_n vec.T where 1_n is the vector of ones |
| 59 | + """ |
| 60 | + def __init__(self, n_rows, vec): |
| 61 | + self.vec = np.asarray(vec).reshape(-1) |
| 62 | + shape = (n_rows, self.vec.shape[0]) |
| 63 | + dtype = self.vec.dtype |
| 64 | + super().__init__(dtype=dtype, shape=shape) |
| 65 | + |
| 66 | + def _matvec(self, x): |
| 67 | + return np.repeat(self.vec.T.dot(x), self.shape[0]) |
| 68 | + |
| 69 | + def _rmatvec(self, x): |
| 70 | + return self.vec * x.sum() |
| 71 | + |
| 72 | + |
| 73 | +def centered_operator(X, center): |
| 74 | + return aslinearoperator(X) - OnesOuterVec(X.shape[0], center) |
| 75 | + |
| 76 | + |
| 77 | +def center_scale_sparse(X, X_offset=None, X_scale=None): |
| 78 | + """ |
| 79 | + Returns a linear operator representing a centered and scaled matrix |
| 80 | +
|
| 81 | + X_cent_scale = (X - X_offset) @ diags(1 / X_scale) |
| 82 | +
|
| 83 | + Output |
| 84 | + ------ |
| 85 | + X_cent_scale: LinearOperator |
| 86 | + """ |
| 87 | + if X_offset is None and X_scale is None: |
| 88 | + return X |
| 89 | + |
| 90 | + if X_offset is not None and X_scale is not None: |
| 91 | + X_offset_scale = X_offset / X_scale |
| 92 | + X_offset_scale = np.array(X_offset_scale).reshape(-1, 1) |
| 93 | + |
| 94 | + elif X_offset is not None: |
| 95 | + X_offset_scale = X_offset |
| 96 | + |
| 97 | + if X_scale is not None: |
| 98 | + X_ = X @ diags(1 / X_scale) |
| 99 | + else: |
| 100 | + X_ = X |
| 101 | + |
| 102 | + return centered_operator(X=X_, center=X_offset_scale) |
| 103 | + |
| 104 | + |
| 105 | +def safe_row_scaled(mat, s): |
| 106 | + if is_sparse_or_lin_op(mat): |
| 107 | + return RowScaled(mat=mat, s=s) |
| 108 | + else: |
| 109 | + return diags(s) @ mat |
| 110 | + |
| 111 | + |
| 112 | +def safe_col_scaled(mat, s): |
| 113 | + if is_sparse_or_lin_op(mat): |
| 114 | + return ColScaled(mat=mat, s=s) |
| 115 | + else: |
| 116 | + return mat @diags(s) |
| 117 | + |
| 118 | + |
| 119 | +class RowScaled(LinearOperator): |
| 120 | + def __init__(self, mat, s): |
| 121 | + self.s = np.array(s).reshape(-1).astype(mat.dtype) |
| 122 | + assert len(self.s) == mat.shape[0] |
| 123 | + self.s = diags(self.s) |
| 124 | + self.mat = mat |
| 125 | + super().__init__(dtype=mat.dtype, shape=mat.shape) |
| 126 | + |
| 127 | + def _matvec(self, x): |
| 128 | + return self.s @ (self.mat @ x) |
| 129 | + |
| 130 | + def _rmatvec(self, x): |
| 131 | + return self.mat.T @ (self.s @ x) |
| 132 | + |
| 133 | + |
| 134 | +class ColScaled(LinearOperator): |
| 135 | + def __init__(self, mat, s): |
| 136 | + self.s = np.array(s).reshape(-1).astype(mat.dtype) |
| 137 | + assert len(self.s) == mat.shape[1] |
| 138 | + self.s = diags(self.s) |
| 139 | + self.mat = mat |
| 140 | + super().__init__(dtype=mat.dtype, shape=mat.shape) |
| 141 | + |
| 142 | + def _matvec(self, x): |
| 143 | + return self.mat @ (self.s @ x) |
| 144 | + |
| 145 | + def _rmatvec(self, x): |
| 146 | + return self.s @ (self.mat.T @ x) |
0 commit comments