Skip to content
Snippets Groups Projects

Created nice wrapper functions around the njit functions for convenience

Merged Anne Poot requested to merge numba-wrapper into master
1 file
+ 26
8
Compare changes
  • Side-by-side
  • Inline
+ 26
8
import numpy as np
import scipy.sparse as spsp
import scipy.sparse.linalg as spspla
from numba import njit
class ICholPrecon:
@@ -13,17 +12,15 @@ class ICholPrecon:
def update(self, sourcematrix):
self._sourcematrix = sourcematrix
self._L = spsp.tril(self._sourcematrix, format='csr')
ichol_data = incomplete_cholesky(self._L.data, self._L.indices, self._L.indptr)
self._L.data = ichol_data
self._L = incomplete_cholesky(self._sourcematrix)
self._LT = self._L.T.tocsr()
def dot(self, lhs):
return self._L @ (self._L.T @ lhs)
def solve(self, rhs):
tmp = solve_triangular(self._L.data, self._L.indices, self._L.indptr, rhs, lower=True)
return solve_triangular(self._LT.data, self._LT.indices, self._LT.indptr, tmp, lower=False)
tmp = solve_triangular(self._L, rhs, lower=True)
return solve_triangular(self._LT, tmp, lower=False)
def get_matrix(self):
return self._L @ self._L.T
@@ -33,6 +30,27 @@ def declare(factory):
factory.declare_precon('ichol', ICholPrecon)
#####################
# wrapper functions #
#####################
def incomplete_cholesky(A):
if not spsp.isspmatrix_csr(A):
raise ValueError('A has to be a sparse matrix in csr format')
L = spsp.tril(A, format='csr')
L.data = incomplete_cholesky_jit(L.data, L.indices, L.indptr)
return L
def solve_triangular(A, b, lower=True):
if not spsp.isspmatrix_csr(A):
raise ValueError('A has to be a sparse matrix in csr format')
return solve_triangular_jit(A.data, A.indices, A.indptr, b, lower)
##########################
# numba helper functions #
##########################
@@ -73,7 +91,7 @@ def idxs2rowscols(indices, indptr):
@njit
def incomplete_cholesky(data, indices, indptr):
def incomplete_cholesky_jit(data, indices, indptr):
L_data = data.copy()
rows, cols = idxs2rowscols(indices, indptr)
@@ -126,7 +144,7 @@ def incomplete_cholesky(data, indices, indptr):
@njit
def solve_triangular(data, indices, indptr, b, lower=True):
def solve_triangular_jit(data, indices, indptr, b, lower=True):
x = np.zeros_like(b)
if lower:
Loading