From 684b7524982c0eb80db2bf6486a6e24718c83066 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Thu, 14 Jan 2021 22:04:56 -0800 Subject: [PATCH 1/2] Return (inv) root of KroneckerProductAddedDiagLinearOperator as lazy Previously, `_root_decomposition` and `_root_inv_decomposition` were returning the (inv) root from the eigendecomposition as a dense tensor, which can be inefficient. This now returns the root as `MatmulLinearOperator` instead. E.g. a matrix-vector product of the root with some vector `v` is now implicitly computed as `q_matrix @ (evals \dot v)` rather than `(q_matrix @ diag(evals)) @ v`, which can make a big difference since `q_matrix` is a Kronecker product. This can help runtime, but more importantly it significantly reduces memory footprint, since we don't need to instantiate the (inv) root, but only the constituent components. This is a clone of https://github.com/cornellius-gp/gpytorch/pull/1430 --- ...cker_product_added_diag_linear_operator.py | 66 +++++++++---------- linear_operator/operators/linear_operator.py | 2 +- linear_operator/utils/__init__.py | 3 +- linear_operator/utils/interpolation.py | 6 +- linear_operator/utils/lanczos.py | 3 +- linear_operator/utils/sparse.py | 9 +-- ...cker_product_added_diag_linear_operator.py | 16 +++-- 7 files changed, 51 insertions(+), 54 deletions(-) diff --git a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py index a59d822..110a3e0 100644 --- a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py +++ b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py @@ -5,16 +5,13 @@ import torch from .added_diag_linear_operator import AddedDiagLinearOperator -from .diag_linear_operator import DiagLinearOperator +from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator +from .matmul_linear_operator import MatmulLinearOperator class
KroneckerProductAddedDiagLinearOperator(AddedDiagLinearOperator): def __init__(self, *linear_operators, preconditioner_override=None): - # TODO: implement the woodbury formula for diagonal tensors that are non constants. - - super(KroneckerProductAddedDiagLinearOperator, self).__init__( - *linear_operators, preconditioner_override=preconditioner_override - ) + super().__init__(*linear_operators, preconditioner_override=preconditioner_override) if len(linear_operators) > 2: raise RuntimeError("An AddedDiagLinearOperator can only have two components") elif isinstance(linear_operators[0], DiagLinearOperator): @@ -34,12 +31,7 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True) inv_quad_term, _ = super().inv_quad_logdet( inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad ) - - if logdet is not False: - logdet_term = self._logdet() - else: - logdet_term = None - + logdet_term = self._logdet() if logdet else None return inv_quad_term, logdet_term def _logdet(self): @@ -53,33 +45,41 @@ def _preconditioner(self): return None, None, None def _solve(self, rhs, preconditioner=None, num_tridiag=0): - # we do the solve in double for numerical stability issues - # TODO: Use fp64 registry once #1213 is addressed + if isinstance(self.diag_tensor, ConstantDiagLinearOperator): + # we do the solve in double for numerical stability issues + # TODO: Use fp64 registry once #1213 is addressed + + rhs_dtype = rhs.dtype + rhs = rhs.double() + + evals, q_matrix = self.linear_operator.symeig(eigenvectors=True) + evals, q_matrix = evals.double(), q_matrix.double() - rhs_dtype = rhs.dtype - rhs = rhs.double() + evals_plus_diagonal = evals + self.diag_tensor.diag() + evals_root = evals_plus_diagonal.pow(0.5) + inv_mat_sqrt = DiagLinearOperator(evals_root.reciprocal()) - evals, q_matrix = self.linear_operator.symeig(eigenvectors=True) - evals, q_matrix = evals.double(), q_matrix.double() + res = q_matrix.transpose(-2, -1).matmul(rhs) + res2 = 
inv_mat_sqrt.matmul(res) - evals_plus_diagonal = evals + self.diag_tensor.diag() - evals_root = evals_plus_diagonal.pow(0.5) - inv_mat_sqrt = DiagLinearOperator(evals_root.reciprocal()) + lhs = q_matrix.matmul(inv_mat_sqrt) + return lhs.matmul(res2).type(rhs_dtype) - res = q_matrix.transpose(-2, -1).matmul(rhs) - res2 = inv_mat_sqrt.matmul(res) + # TODO: implement woodbury formula for non-constant Kronecker-structured diagonal operators - lhs = q_matrix.matmul(inv_mat_sqrt) - return lhs.matmul(res2).type(rhs_dtype) + return super()._solve(rhs, preconditioner=None, num_tridiag=0) def _root_decomposition(self): - evals, q_matrix = self.linear_operator.symeig(eigenvectors=True) - updated_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(0.5)) - matrix_root = q_matrix.matmul(updated_evals) - return matrix_root + if isinstance(self.diag_tensor, ConstantDiagLinearOperator): + # we can be use eigendecomposition and shift the eigenvalues + evals, q_matrix = self.linear_operator.symeig(eigenvectors=True) + updated_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(0.5)) + return MatmulLinearOperator(q_matrix, updated_evals) + return super()._root_decomposition() def _root_inv_decomposition(self, initial_vectors=None): - evals, q_matrix = self.linear_operator.symeig(eigenvectors=True) - inv_sqrt_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(-0.5)) - matrix_inv_root = q_matrix.matmul(inv_sqrt_evals) - return matrix_inv_root + if isinstance(self.diag_tensor, ConstantDiagLinearOperator): + evals, q_matrix = self.linear_operator.symeig(eigenvectors=True) + inv_sqrt_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(-0.5)) + return MatmulLinearOperator(q_matrix, inv_sqrt_evals) + return super()._root_inv_decomposition(initial_vectors=initial_vectors) diff --git a/linear_operator/operators/linear_operator.py b/linear_operator/operators/linear_operator.py index 5a83eee..790ddd6 100644 --- 
a/linear_operator/operators/linear_operator.py +++ b/linear_operator/operators/linear_operator.py @@ -1925,7 +1925,7 @@ def __getitem__(self, index): # Alternatively, if we're using tensor indices and losing dimensions, use self._get_indices if row_col_are_absorbed: # Convert all indices into tensor indices - *batch_indices, row_index, col_index, = _convert_indices_to_tensors( + (*batch_indices, row_index, col_index,) = _convert_indices_to_tensors( self, (*batch_indices, row_index, col_index) ) res = self._get_indices(row_index, col_index, *batch_indices) diff --git a/linear_operator/utils/__init__.py b/linear_operator/utils/__init__.py index 2ee2eb7..a753cfd 100644 --- a/linear_operator/utils/__init__.py +++ b/linear_operator/utils/__init__.py @@ -9,8 +9,7 @@ def prod(items): - """ - """ + """""" if len(items): res = items[0] for item in items[1:]: diff --git a/linear_operator/utils/interpolation.py b/linear_operator/utils/interpolation.py index dad521b..c7c39a2 100644 --- a/linear_operator/utils/interpolation.py +++ b/linear_operator/utils/interpolation.py @@ -156,8 +156,7 @@ def interpolate(self, x_grid: List[torch.Tensor], x_target: torch.Tensor, interp def left_interp(interp_indices, interp_values, rhs): - """ - """ + """""" is_vector = rhs.ndimension() == 1 if is_vector: @@ -181,8 +180,7 @@ def left_interp(interp_indices, interp_values, rhs): def left_t_interp(interp_indices, interp_values, rhs, output_dim): - """ - """ + """""" from .. 
import dsmm is_vector = rhs.ndimension() == 1 diff --git a/linear_operator/utils/lanczos.py b/linear_operator/utils/lanczos.py index ad3ebf5..999cc12 100644 --- a/linear_operator/utils/lanczos.py +++ b/linear_operator/utils/lanczos.py @@ -18,8 +18,7 @@ def lanczos_tridiag( num_init_vecs=1, tol=1e-5, ): - """ - """ + """""" # Determine batch mode multiple_init_vecs = False diff --git a/linear_operator/utils/sparse.py b/linear_operator/utils/sparse.py index d3729eb..cc0d332 100644 --- a/linear_operator/utils/sparse.py +++ b/linear_operator/utils/sparse.py @@ -140,8 +140,7 @@ def sparse_eye(size): def sparse_getitem(sparse, idxs): - """ - """ + """""" if not isinstance(idxs, tuple): idxs = (idxs,) @@ -201,8 +200,7 @@ def sparse_getitem(sparse, idxs): def sparse_repeat(sparse, *repeat_sizes): - """ - """ + """""" if len(repeat_sizes) == 1 and isinstance(repeat_sizes, tuple): repeat_sizes = repeat_sizes[0] @@ -243,8 +241,7 @@ def sparse_repeat(sparse, *repeat_sizes): def to_sparse(dense): - """ - """ + """""" mask = dense.ne(0) indices = mask.nonzero(as_tuple=False) if indices.storage(): diff --git a/test/operators/test_kronecker_product_added_diag_linear_operator.py b/test/operators/test_kronecker_product_added_diag_linear_operator.py index e2584ba..f482367 100644 --- a/test/operators/test_kronecker_product_added_diag_linear_operator.py +++ b/test/operators/test_kronecker_product_added_diag_linear_operator.py @@ -9,6 +9,7 @@ from linear_operator import settings from linear_operator.operators import ( + ConstantDiagLinearOperator, DenseLinearOperator, DiagLinearOperator, KroneckerProductAddedDiagLinearOperator, @@ -23,7 +24,7 @@ class TestKroneckerProductAddedDiagLinearOperator(unittest.TestCase, LinearOpera should_call_lanczos = False should_call_cg = False - def create_linear_operator(self): + def create_linear_operator(self, constant_diag=True): a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float) b = torch.tensor([[2, 1], [1, 2]], 
dtype=torch.float) c = torch.tensor([[4, 0.5, 1, 0], [0.5, 4, -1, 0], [1, -1, 3, 0], [0, 0, 0, 4]], dtype=torch.float) @@ -33,10 +34,13 @@ def create_linear_operator(self): kp_linear_operator = KroneckerProductLinearOperator( DenseLinearOperator(a), DenseLinearOperator(b), DenseLinearOperator(c) ) - - return KroneckerProductAddedDiagLinearOperator( - kp_linear_operator, DiagLinearOperator(0.1 * torch.ones(kp_linear_operator.shape[-1])) - ) + if constant_diag: + diag_linear_operator = ConstantDiagLinearOperator( + torch.tensor([0.25], dtype=torch.float), kp_linear_operator.shape[-1], + ) + else: + diag_linear_operator = DiagLinearOperator(0.5 * torch.rand(kp_linear_operator.shape[-1], dtype=torch.float)) + return KroneckerProductAddedDiagLinearOperator(kp_linear_operator, diag_linear_operator) def evaluate_linear_operator(self, linear_operator): tensor = linear_operator._linear_operator.to_dense() @@ -59,7 +63,7 @@ def test_root_inv_decomposition_no_cholesky(self): with mock.patch.object(linear_operator, "cholesky") as chol_mock: root_approx = linear_operator.root_inv_decomposition() res = root_approx.matmul(test_mat) - actual = linear_operator.inv_matmul(test_mat) + actual = torch.solve(test_mat, linear_operator.to_dense()).solution self.assertAllClose(res, actual, rtol=0.05, atol=0.02) chol_mock.assert_not_called() From db00d44e89c41a2bd5d6fcccc5983db7167201bd Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Fri, 15 Jan 2021 07:31:39 -0800 Subject: [PATCH 2/2] Fall back for other functions if necessary, set _diag_is_constant flag --- ...cker_product_added_diag_linear_operator.py | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py index 110a3e0..c821389 100644 --- a/linear_operator/operators/kronecker_product_added_diag_linear_operator.py +++
b/linear_operator/operators/kronecker_product_added_diag_linear_operator.py @@ -24,28 +24,33 @@ def __init__(self, *linear_operators, preconditioner_override=None): raise RuntimeError( "One of the LinearOperators input to AddedDiagLinearOperator must be a DiagLinearOperator!" ) + self._diag_is_constant = isinstance(self.diag_tensor, ConstantDiagLinearOperator) def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True): - # we want to call the standard InvQuadLogDet to easily get the probe vectors and do the - # solve but we only want to cache the probe vectors for the backwards - inv_quad_term, _ = super().inv_quad_logdet( - inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad - ) - logdet_term = self._logdet() if logdet else None - return inv_quad_term, logdet_term + if self._diag_is_constant: + # we want to call the standard InvQuadLogDet to easily get the probe vectors and do the + # solve but we only want to cache the probe vectors for the backwards + inv_quad_term, _ = super().inv_quad_logdet( + inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad + ) + logdet_term = self._logdet() if logdet else None + return inv_quad_term, logdet_term + return super().inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=logdet, reduce_inv_quad=reduce_inv_quad) def _logdet(self): - # symeig requires computing the eigenvectors so that it's differentiable - evals, _ = self.linear_operator.symeig(eigenvectors=True) - evals_plus_diag = evals + self.diag_tensor.diag() - return torch.log(evals_plus_diag).sum(dim=-1) + if self._diag_is_constant: + # symeig requires computing the eigenvectors so that it's differentiable + evals, _ = self.linear_operator.symeig(eigenvectors=True) + evals_plus_diag = evals + self.diag_tensor.diag() + return torch.log(evals_plus_diag).sum(dim=-1) + return super()._logdet() def _preconditioner(self): # solves don't use CG so don't waste time computing it return None, None, None def _solve(self, 
rhs, preconditioner=None, num_tridiag=0): - if isinstance(self.diag_tensor, ConstantDiagLinearOperator): + if self._diag_is_constant: # we do the solve in double for numerical stability issues # TODO: Use fp64 registry once #1213 is addressed @@ -67,10 +72,10 @@ def _solve(self, rhs, preconditioner=None, num_tridiag=0): # TODO: implement woodbury formula for non-constant Kronecker-structured diagonal operators - return super()._solve(rhs, preconditioner=None, num_tridiag=0) + return super()._solve(rhs, preconditioner=preconditioner, num_tridiag=num_tridiag) def _root_decomposition(self): - if isinstance(self.diag_tensor, ConstantDiagLinearOperator): + if self._diag_is_constant: # we can be use eigendecomposition and shift the eigenvalues evals, q_matrix = self.linear_operator.symeig(eigenvectors=True) updated_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(0.5)) @@ -78,7 +83,7 @@ def _root_decomposition(self): return super()._root_decomposition() def _root_inv_decomposition(self, initial_vectors=None): - if isinstance(self.diag_tensor, ConstantDiagLinearOperator): + if self._diag_is_constant: evals, q_matrix = self.linear_operator.symeig(eigenvectors=True) inv_sqrt_evals = DiagLinearOperator((evals + self.diag_tensor.diag()).pow(-0.5)) return MatmulLinearOperator(q_matrix, inv_sqrt_evals)