I'll be disregarding torch, as this doesn't have much to do with machine learning specifically. Instead, let's develop a demonstration with more general-purpose numerical libraries, NumPy and SciPy. Most of the methods have direct counterparts in torch anyway, so porting this back shouldn't be too difficult.
The outputs of your methods 1 and 2 agree to within ~1e-9 (but not to machine precision). I assume that this is not a problem, and I also assume that you've implemented Sherman-Morrison correctly.
Your second method needs standard vectorisation. Once it is vectorised, no loops are necessary.
You write that
W is not necessarily a sparse matrix
This matters: depending on the structure of the matrix, you should adopt different methods for the inverse. In particular, your sample matrix is banded and symmetric (and, being real, therefore also Hermitian). As I show below, if you know which of these properties holds for any given call, you can make the inverse more efficient. However, in terms of time consumption, vectorisation is far more important than improving the inverse.
For the banded case, some setup is needed (for which I show a very inefficient method). If your matrix stays banded and its bandwidths l and u (the number of non-zero sub- and super-diagonals) are small, the inverse speeds up further.
The benchmark uses a matrix that has the same value range, is larger, is still banded and Hermitian, and has smaller l and u. Again, for your actual data you'll probably need to benchmark again to find which method is most appropriate.
import timeit
import typing
from functools import partial
import numpy as np
import scipy.linalg
def nested_inverse(A: np.ndarray) -> np.ndarray:
    """OP method 1: brute force, one full matrix inverse per (i, j) pair.

    For every pair with i <= j, the symmetric entries A[i, j] and A[j, i] are
    zeroed in a copy of A, n*I - A' is inverted, and only entry [i, j] of that
    inverse is kept. The output is filled symmetrically. This performs
    n*(n+1)/2 dense inversions, so it is only usable for tiny matrices.
    """
    size, _ = A.shape
    scaled_eye = size * np.eye(size)
    result = np.empty((size, size))
    for row, col in zip(*np.triu_indices(size)):
        pruned = A.copy()
        pruned[row, col] = pruned[col, row] = 0
        entry = np.linalg.inv(scaled_eye - pruned)[row, col]
        result[row, col] = result[col, row] = entry
    return result
def sherman_morrison_op(A: np.ndarray) -> np.ndarray:
    """OP method 2: Sherman-Morrison correction evaluated pair by pair.

    With inv = (n*I - A)^-1 and b = A / (1 + A * inv) element-wise, entry
    [i, j] is taken from inv - b[i, j] * outer(inv[:, i], inv[j, :]).
    The loop deliberately builds a full n x n correction per pair, as in the
    original question, so the benchmark can contrast it with the vectorised
    variants.
    """
    n_node, _ = A.shape
    base_inv = np.linalg.inv(n_node * np.eye(n_node) - A)
    b = np.divide(A, 1 + np.einsum('ij,ij->ij', A, base_inv))
    stacked = base_inv[:, np.newaxis, ...]  # (n, 1, n)
    out = np.empty((n_node, n_node))
    for i, j in zip(*np.triu_indices(n_node)):
        column_i = stacked[:, :, i]  # (n, 1): column i of base_inv
        row_j = stacked[j, :, :]     # (1, n): row j of base_inv
        corrected = base_inv - b[i, j] * np.matmul(column_i, row_j)
        out[i, j] = out[j, i] = corrected[i, j]
    return out
def sherman_morrison_mod(a: np.ndarray) -> np.ndarray:
    """Vectorised version of OP method 2 (no Python loops).

    With inv = (n*I - a)^-1, entry [i, j] of the result is
    inv[i, j] - b[i, j] * inv[i, i] * inv[j, j], where
    b = a / (1 + a * inv) element-wise.
    """
    n, _ = a.shape
    inverse = np.linalg.inv(n * np.eye(n) - a)
    diag = np.diagonal(inverse)            # (n,)
    weights = a / (1 + a * inverse)        # (n, n)
    return inverse - weights * np.outer(diag, diag)
def sherman_morrison_scipy(a: np.ndarray) -> np.ndarray:
    """Same as ``sherman_morrison_mod`` but inverts with ``scipy.linalg.inv``.

    The temporary n*I - a may be overwritten, and the finiteness check is
    skipped to save time.
    """
    n, _ = a.shape
    inverse = scipy.linalg.inv(
        a=n * np.eye(n) - a,
        overwrite_a=True,
        check_finite=False,
    )
    diag = np.diagonal(inverse)            # (n,)
    weights = a / (1 + a * inverse)        # (n, n)
    return inverse - weights * np.outer(diag, diag)
def sherman_morrison_structural(
    a: np.ndarray,
    assume: typing.Literal['gen', 'sym', 'her', 'pos'] = 'gen',
) -> np.ndarray:
    """Sherman-Morrison correction, inverting via ``scipy.linalg.solve``.

    The inverse of n*I - a is obtained by solving against the identity.
    ``assume`` is forwarded as ``assume_a`` so SciPy can pick a cheaper
    factorisation when the caller knows the structure of n*I - a.
    """
    size, _ = a.shape
    identity = np.eye(size)
    system = size * identity - a
    inverse = scipy.linalg.solve(
        a=system, b=identity,
        assume_a=assume,
        overwrite_a=True, overwrite_b=True, check_finite=False,
    )  # (n, n)
    diag = np.diagonal(inverse)            # (n,)
    weights = a / (1 + a * inverse)        # (n, n)
    return inverse - weights * np.outer(diag, diag)
def sherman_morrison_banded(a: np.ndarray, a_banded: np.ndarray, l: int, u: int) -> np.ndarray:
    """Sherman-Morrison correction, inverting n*I - a with a banded solver.

    Parameters
    ----------
    a : dense (n, n) matrix, used for the element-wise correction.
    a_banded : ``a`` in the ``scipy.linalg.solve_banded`` layout, i.e.
        ``a_banded[u + i - j, j] == a[i, j]``. It is not modified
        (negation creates a new array).
    l, u : number of non-zero sub- and super-diagonals of ``a``.

    Returns
    -------
    The (n, n) corrected matrix, same formula as ``sherman_morrison_mod``.
    """
    n, _ = a.shape
    ab = -a_banded  # -a in banded layout; a fresh array, safe to overwrite
    # Row u is the main diagonal. n*I - a has diagonal n - a[i, i], so add n
    # rather than overwriting: a may have a non-zero diagonal.
    ab[u] += n
    inv = scipy.linalg.solve_banded(
        (l, u),
        ab=ab, overwrite_ab=True, check_finite=False,
        b=np.eye(n), overwrite_b=True,
    )  # (n, n)
    invdiag = np.diagonal(inv)  # (n,)
    b = a / (1 + a * inv)       # (n, n)
    return inv - b * np.outer(invdiag, invdiag)
def make_banded(W: np.ndarray) -> tuple[
    np.ndarray,
    int, int,
]:
    """Convert a dense square matrix to the ``scipy.linalg.solve_banded`` layout.

    Returns ``(banded, l, u)``:

    - ``l`` is the number of non-zero sub-diagonals (lower bandwidth).
    - ``u`` is the number of non-zero super-diagonals (upper bandwidth).
    - ``banded`` has shape ``(u + 1 + l, n)`` with
      ``banded[u + i - j, j] == W[i, j]``.

    Slots of ``banded`` that correspond to no entry of ``W`` (the corners)
    are NaN.
    """
    n, _ = W.shape

    def outermost_nonzero(sign: int) -> int:
        # Largest k >= 1 whose diagonal at offset sign*k has a non-zero entry,
        # or 0 if there is none. np.diagonal: offset > 0 is above the main
        # diagonal, offset < 0 below it.
        return next(
            (k for k in range(n - 1, 0, -1)
             if np.count_nonzero(np.diagonal(W, offset=sign * k))),
            0,
        )

    l = outermost_nonzero(-1)  # sub-diagonals
    u = outermost_nonzero(+1)  # super-diagonals

    banded = np.full(shape=(u + 1 + l, n), fill_value=np.nan)
    for row in range(u + l + 1):
        offset = u - row  # diagonal stored in this row (j - i)
        diagonal = np.diagonal(W, offset=offset)
        if offset >= 0:
            banded[row, offset:] = diagonal      # super-diagonal: right-aligned
        else:
            banded[row, :n + offset] = diagonal  # sub-diagonal: left-aligned
    return banded, l, u
def bind_methods(W: np.ndarray) -> tuple[
    typing.Callable[[np.ndarray], np.ndarray], ...
]:
    """Return every candidate implementation as a one-argument callable.

    The order is significant: entry 0 is the brute-force nested inverse and
    entry 1 is OP's loop-based Sherman-Morrison. The banded variant is
    pre-bound to W's banded representation, so the tuple is only valid for
    matrices with that same band layout. W must be Hermitian.
    """
    assert scipy.linalg.ishermitian(W)
    banded, l, u = make_banded(W)
    # 'pos' is left out. The sample W is not positive-definite, and
    # assume_a='pos' would require n*I - W to be, which nothing here checks.
    structural = [
        partial(sherman_morrison_structural, assume=kind)
        for kind in ('gen', 'sym', 'her')
    ]
    return (
        nested_inverse,
        sherman_morrison_op,
        sherman_morrison_mod,
        sherman_morrison_scipy,
        *structural,
        partial(sherman_morrison_banded, l=l, u=u, a_banded=banded),
    )
def test() -> None:
    """Cross-check all methods on a small 9x9 symmetric sample matrix.

    OP's methods 1 and 2 are compared loosely (atol=1e-8). Every other method
    must reproduce method 2 to within atol=1e-16.
    """
    W = np.array([
        [0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 3, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
    ])
    nested, op, *candidates = bind_methods(W)
    op_nested = nested(W)
    reference = op(W)

    max_error = np.abs(op_nested - reference).max()
    print(f'Error from nested to Sherman-Morrison: {max_error:.1e}')
    print()
    assert np.allclose(op_nested, reference, rtol=0, atol=1e-8)

    for candidate in candidates:
        assert np.allclose(reference, candidate(W), rtol=0, atol=1e-16)
def benchmark() -> None:
    """Time each method on a seeded random 100x100 symmetric tridiagonal matrix.

    Entries are drawn from {0, 1, 2}. The nested inverse is listed but not
    timed; OP's loop-based method gets fewer repetitions because it is slow.
    """
    rng = np.random.default_rng(seed=0)
    # Draw the off-diagonal first: the RNG call order fixes the sample.
    off_diag = rng.integers(low=0, high=3, size=99)
    main_diag = rng.integers(low=0, high=3, size=100)
    w = np.diag(main_diag) + np.diag(off_diag, k=1) + np.diag(off_diag, k=-1)

    nested, *timed = bind_methods(w)
    print(nested)
    print('Too slow to care about')
    for method in timed:
        label = str(method)
        print(label.splitlines()[0])
        repeats = 5 if '_op' in label else 100
        elapsed = timeit.timeit(partial(method, w), number=repeats)
        print(f'{elapsed/repeats*1e3:.2f} ms')
if __name__ == '__main__':
    # Verify correctness first, then time the methods.
    for step in (test, benchmark):
        step()
Error from nested to Sherman-Morrison: 3.7e-09
<function nested_inverse at 0x000002A242822FC0>
Too slow to care about
<function sherman_morrison_op at 0x000002A242BDF2E0>
149.03 ms
<function sherman_morrison_mod at 0x000002A273EF0360>
1.36 ms
<function sherman_morrison_scipy at 0x000002A273E30E00>
0.49 ms
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='gen')
0.32 ms
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='sym')
0.61 ms
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='her')
0.38 ms
functools.partial(<function sherman_morrison_banded at 0x000002A273F76660>, l=1, u=1, a_banded=array([[nan, 2., 1., 1., 0., 0., 0., 0., 0., 0., 2., 1., 2.,
0.17 ms