Skip to main content
blurb about benchmark
Source Link
Reinderien
  • 71.2k
  • 5
  • 76
  • 257

The benchmark uses a matrix that has the same value range, is larger, still banded-Hermitian and has lower L and U. Again, for your actual data you'll probably need to benchmark again to find which method is most appropriate.

import timeit
import typing
from functools import partial

import numpy as np
import scipy.linalg


def nested_inverse(A: np.ndarray) -> np.ndarray:
    """OP method 1: one full matrix inversion per index pair.

    For every pair (i, j) with i <= j, a copy of A gets its [i, j] and
    [j, i] entries zeroed, n*I minus that copy is inverted (n being the row
    count of A), and element [i, j] of the inverse is written to both
    [i, j] and [j, i] of the result.
    """
    size, _ = A.shape
    scaled_identity = size * np.eye(size)
    result = np.empty((size, size))

    # Row-major walk over the upper triangle, main diagonal included.
    for row, col in zip(*np.triu_indices(size)):
        punctured = A.copy()
        punctured[row, col] = 0
        punctured[col, row] = 0
        entry = np.linalg.inv(scaled_identity - punctured)[row, col]
        result[row, col] = entry
        result[col, row] = entry
    return result


def sherman_morrison_op(A: np.ndarray) -> np.ndarray:
    """OP method 2: a single inversion followed by a per-pair correction.

    inv(n*I - A) is computed once (n being the row count of A).  For each
    pair (i, j) with i <= j, the full matrix
    inv - b[i, j] * (column i of inv)(row j of inv) is formed and only its
    [i, j] element is kept, mirrored across the diagonal.
    """
    n_node, _ = A.shape
    base_inv = np.linalg.inv(n_node * np.eye(n_node) - A)
    scale = A / (1 + np.multiply(A, base_inv))

    # Extra length-1 axis so the slices below stay two-dimensional:
    # an (n, 1) column and a (1, n) row.
    stacked = base_inv[:, np.newaxis, :]
    out = np.empty((n_node, n_node))

    for i in range(n_node):
        column = stacked[:, :, i]
        for j in range(i, n_node):
            row = stacked[j, :, :]
            corrected = base_inv - scale[i, j] * (column @ row)
            out[i, j] = out[j, i] = corrected[i, j]
    return out


def sherman_morrison_mod(a: np.ndarray) -> np.ndarray:
    """Loop-free variant of the Sherman-Morrison computation using NumPy's inverse.

    With M = inv(n*I - a) and d = diag(M), returns
    M - a/(1 + a*M) * outer(d, d); the products and the quotient other than
    the outer product are elementwise.
    """
    order, _ = a.shape
    base_inv = np.linalg.inv(order * np.eye(order) - a)  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag_products = diag[:, np.newaxis] * diag[np.newaxis, :]  # (n, n)
    return base_inv - gain * diag_products


def sherman_morrison_scipy(a: np.ndarray) -> np.ndarray:
    """Loop-free variant of the Sherman-Morrison computation using scipy.linalg.inv.

    With M = inv(n*I - a) and d = diag(M), returns
    M - a/(1 + a*M) * outer(d, d).  The temporary n*I - a may be overwritten
    by SciPy and finiteness checking is disabled.
    """
    order, _ = a.shape
    shifted = order * np.eye(order) - a
    base_inv = scipy.linalg.inv(shifted, overwrite_a=True, check_finite=False)  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag_products = diag[:, np.newaxis] * diag[np.newaxis, :]  # (n, n)
    return base_inv - gain * diag_products


def sherman_morrison_structural(
    a: np.ndarray,
    assume: typing.Literal['gen', 'sym', 'her', 'pos'] = 'gen',
) -> np.ndarray:
    """Loop-free variant that obtains the inverse by solving against the identity.

    M = inv(n*I - a) is computed as scipy.linalg.solve(n*I - a, I), with
    `assume` forwarded as SciPy's `assume_a` structure hint.  Returns
    M - a/(1 + a*M) * outer(d, d) where d = diag(M).
    """
    order, _ = a.shape
    identity = np.eye(order)
    system = order * identity - a
    # Both operands are throwaway temporaries, so SciPy may overwrite them.
    base_inv = scipy.linalg.solve(
        system, identity,
        assume_a=assume, check_finite=False,
        overwrite_a=True, overwrite_b=True,
    )  # (n, n)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    return base_inv - gain * np.outer(diag, diag)


def sherman_morrison_banded(a: np.ndarray, a_banded: np.ndarray, l: int, u: int) -> np.ndarray:
    """Loop-free variant that obtains the inverse with a banded solver.

    Parameters:
        a: square matrix.
        a_banded: the same matrix in the diagonal-ordered layout expected by
            scipy.linalg.solve_banded, with `u` super-diagonal rows above the
            main-diagonal row and `l` sub-diagonal rows below it.  It is not
            modified.
        l: number of sub-diagonals stored in `a_banded`.
        u: number of super-diagonals stored in `a_banded`.

    Returns:
        M - a/(1 + a*M) * outer(d, d), where M = inv(n*I - a), d = diag(M)
        and n is the row count of `a`.
    """
    n, _ = a.shape
    ab = -a_banded  # negation allocates a fresh array, so the caller's data is untouched
    # Row u is the main diagonal.  Add n there rather than assigning it, so
    # that a nonzero diagonal of `a` is kept and ab really stores n*I - a.
    ab[u] += n
    inv = scipy.linalg.solve_banded(
        (l, u),
        ab=ab, overwrite_ab=True, check_finite=False,
        b=np.eye(n), overwrite_b=True,
    )  # n,n
    b = a/(1 + a*inv)  # n,n
    invdiag = np.diagonal(inv)  # n
    c = inv - b*np.outer(invdiag, invdiag)  # n,n
    return c


def make_banded(W: np.ndarray) -> tuple[
    np.ndarray,
    int, int,
]:
    """Pack the nonzero band of square matrix W into solve_banded() layout.

    Returns (banded, l, u).  `l` is the number of sub-diagonals and `u` the
    number of super-diagonals out to the furthest one holding a nonzero
    entry.  `banded` has shape (u + 1 + l, n) and satisfies
    banded[u + i - j, j] == W[i, j] inside the band; slots that correspond to
    no element of W are left as NaN.
    """
    n, _ = W.shape

    def outermost(sign: int) -> int:
        """Largest k >= 1 whose diagonal at offset sign*k has a nonzero, else 0."""
        for k in range(n - 1, 0, -1):
            if np.count_nonzero(np.diagonal(W, offset=sign * k)):
                return k
        return 0

    # np.diagonal uses negative offsets below the main diagonal and positive
    # offsets above it.
    l = outermost(-1)
    u = outermost(+1)
    banded = np.full(shape=(u + 1 + l, n), fill_value=np.nan)

    for k in range(-l, u + 1):
        diag = np.diagonal(W, offset=k)
        if k >= 0:
            # Super-diagonal k occupies columns k..n-1 of row u - k.
            banded[u - k, k:] = diag
        else:
            # Sub-diagonal |k| occupies columns 0..n-|k|-1 of row u + |k|.
            banded[u - k, :n + k] = diag
    return banded, l, u


def bind_methods(W: np.ndarray) -> tuple[
    typing.Callable[[np.ndarray], np.ndarray], ...
]:
    """Return every implementation as a callable that takes only the matrix.

    Asserts that W is Hermitian, derives its banded form once, and binds the
    extra arguments of the structural and banded variants with partial().
    """
    assert scipy.linalg.ishermitian(W)
    band_storage, lower, upper = make_banded(W)

    # 'pos' is left out: the sample is not positive-definite.
    structural_variants = tuple(
        partial(sherman_morrison_structural, assume=kind)
        for kind in ('gen', 'sym', 'her')
    )
    banded_variant = partial(
        sherman_morrison_banded, l=lower, u=upper, a_banded=band_storage,
    )
    return (
        nested_inverse,
        sherman_morrison_op,
        sherman_morrison_mod,
        sherman_morrison_scipy,
        *structural_variants,
        banded_variant,
    )


def test() -> None:
    """Check on a fixed 9x9 sample that all bound methods agree.

    The first method is compared with the second (the reference) to an
    absolute tolerance of 1e-8, and every remaining method with the
    reference to 1e-16.
    """
    W = np.array([
        [0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 3, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
    ])

    nested, reference_method, *remaining = bind_methods(W)
    op_nested = nested(W)
    reference = reference_method(W)
    worst = np.abs(op_nested - reference).max()
    print(f'Error from nested to Sherman-Morrison: {worst:.1e}')
    print()
    assert np.allclose(op_nested, reference, rtol=0, atol=1e-8)

    for candidate in remaining:
        assert np.allclose(reference, candidate(W), rtol=0, atol=1e-16)


def benchmark() -> None:
    """Time every bound method except the first on a seeded 100x100 tridiagonal matrix.

    The matrix has a random integer main diagonal and one random
    off-diagonal mirrored above and below it, all drawn from {0, 1, 2}.
    The first method is only printed, not timed.  Each remaining method's
    mean time per call is printed in milliseconds.
    """
    rand = np.random.default_rng(seed=0)
    # Draw order is kept (off-diagonal first) so the seeded matrix is unchanged.
    off_diag = rand.integers(low=0, high=3, size=99)
    main_diag = rand.integers(low=0, high=3, size=100)
    w = np.diag(main_diag) + np.diag(off_diag, k=1) + np.diag(off_diag, k=-1)

    slowest, *timed = bind_methods(w)
    print(slowest)
    print('Too slow to care about')
    print()

    for method in timed:
        label = str(method)
        print(label.splitlines()[0])
        # The per-pair '_op' method is far slower than the others, so it gets
        # fewer repetitions (previously 105, i.e. more than the fast methods).
        n = 5 if '_op' in label else 100
        t = timeit.timeit(lambda: method(w), number=n)
        print(f'{t/n*1e3:.2f} ms')
        print()


if __name__ == '__main__':
    # Run the agreement checks first, then the timings.
    test()
    benchmark()
Error from nested to Sherman-Morrison: 3.7e-09

<function nested_inverse at 0x000002A242822FC0>
Too slow to care about
 
<function sherman_morrison_op at 0x000002A242BDF2E0>
149.03 ms
 
<function sherman_morrison_mod at 0x000002A273EF0360>
1.36 ms
 
<function sherman_morrison_scipy at 0x000002A273E30E00>
0.49 ms
 
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='gen')
0.32 ms
 
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='sym')
0.61 ms
 
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='her')
0.38 ms
 
functools.partial(<function sherman_morrison_banded at 0x000002A273F76660>, l=1, u=1, a_banded=array([[nan,  2.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  1.,  2.,
0.17 ms
import timeit
import typing
from functools import partial

import numpy as np
import scipy.linalg


def nested_inverse(A: np.ndarray) -> np.ndarray:
    """OP method 1: one full matrix inversion per index pair.

    For every pair (i, j) with i <= j, a copy of A gets its [i, j] and
    [j, i] entries zeroed, n*I minus that copy is inverted (n being the row
    count of A), and element [i, j] of the inverse is written to both
    [i, j] and [j, i] of the result.
    """
    size, _ = A.shape
    scaled_identity = size * np.eye(size)
    result = np.empty((size, size))

    # Row-major walk over the upper triangle, main diagonal included.
    for row, col in zip(*np.triu_indices(size)):
        punctured = A.copy()
        punctured[row, col] = 0
        punctured[col, row] = 0
        entry = np.linalg.inv(scaled_identity - punctured)[row, col]
        result[row, col] = entry
        result[col, row] = entry
    return result


def sherman_morrison_op(A: np.ndarray) -> np.ndarray:
    """OP method 2: a single inversion followed by a per-pair correction.

    inv(n*I - A) is computed once (n being the row count of A).  For each
    pair (i, j) with i <= j, the full matrix
    inv - b[i, j] * (column i of inv)(row j of inv) is formed and only its
    [i, j] element is kept, mirrored across the diagonal.
    """
    n_node, _ = A.shape
    base_inv = np.linalg.inv(n_node * np.eye(n_node) - A)
    scale = A / (1 + np.multiply(A, base_inv))

    # Extra length-1 axis so the slices below stay two-dimensional:
    # an (n, 1) column and a (1, n) row.
    stacked = base_inv[:, np.newaxis, :]
    out = np.empty((n_node, n_node))

    for i in range(n_node):
        column = stacked[:, :, i]
        for j in range(i, n_node):
            row = stacked[j, :, :]
            corrected = base_inv - scale[i, j] * (column @ row)
            out[i, j] = out[j, i] = corrected[i, j]
    return out


def sherman_morrison_mod(a: np.ndarray) -> np.ndarray:
    """Loop-free variant of the Sherman-Morrison computation using NumPy's inverse.

    With M = inv(n*I - a) and d = diag(M), returns
    M - a/(1 + a*M) * outer(d, d); the products and the quotient other than
    the outer product are elementwise.
    """
    order, _ = a.shape
    base_inv = np.linalg.inv(order * np.eye(order) - a)  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag_products = diag[:, np.newaxis] * diag[np.newaxis, :]  # (n, n)
    return base_inv - gain * diag_products


def sherman_morrison_scipy(a: np.ndarray) -> np.ndarray:
    """Loop-free variant of the Sherman-Morrison computation using scipy.linalg.inv.

    With M = inv(n*I - a) and d = diag(M), returns
    M - a/(1 + a*M) * outer(d, d).  The temporary n*I - a may be overwritten
    by SciPy and finiteness checking is disabled.
    """
    order, _ = a.shape
    shifted = order * np.eye(order) - a
    base_inv = scipy.linalg.inv(shifted, overwrite_a=True, check_finite=False)  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag_products = diag[:, np.newaxis] * diag[np.newaxis, :]  # (n, n)
    return base_inv - gain * diag_products


def sherman_morrison_structural(
    a: np.ndarray,
    assume: typing.Literal['gen', 'sym', 'her', 'pos'] = 'gen',
) -> np.ndarray:
    """Loop-free variant that obtains the inverse by solving against the identity.

    M = inv(n*I - a) is computed as scipy.linalg.solve(n*I - a, I), with
    `assume` forwarded as SciPy's `assume_a` structure hint.  Returns
    M - a/(1 + a*M) * outer(d, d) where d = diag(M).
    """
    order, _ = a.shape
    identity = np.eye(order)
    system = order * identity - a
    # Both operands are throwaway temporaries, so SciPy may overwrite them.
    base_inv = scipy.linalg.solve(
        system, identity,
        assume_a=assume, check_finite=False,
        overwrite_a=True, overwrite_b=True,
    )  # (n, n)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    return base_inv - gain * np.outer(diag, diag)


def sherman_morrison_banded(a: np.ndarray, a_banded: np.ndarray, l: int, u: int) -> np.ndarray:
    """Loop-free variant that obtains the inverse with a banded solver.

    Parameters:
        a: square matrix.
        a_banded: the same matrix in the diagonal-ordered layout expected by
            scipy.linalg.solve_banded, with `u` super-diagonal rows above the
            main-diagonal row and `l` sub-diagonal rows below it.  It is not
            modified.
        l: number of sub-diagonals stored in `a_banded`.
        u: number of super-diagonals stored in `a_banded`.

    Returns:
        M - a/(1 + a*M) * outer(d, d), where M = inv(n*I - a), d = diag(M)
        and n is the row count of `a`.
    """
    n, _ = a.shape
    ab = -a_banded  # negation allocates a fresh array, so the caller's data is untouched
    # Row u is the main diagonal.  Add n there rather than assigning it, so
    # that a nonzero diagonal of `a` is kept and ab really stores n*I - a.
    ab[u] += n
    inv = scipy.linalg.solve_banded(
        (l, u),
        ab=ab, overwrite_ab=True, check_finite=False,
        b=np.eye(n), overwrite_b=True,
    )  # n,n
    b = a/(1 + a*inv)  # n,n
    invdiag = np.diagonal(inv)  # n
    c = inv - b*np.outer(invdiag, invdiag)  # n,n
    return c


def make_banded(W: np.ndarray) -> tuple[
    np.ndarray,
    int, int,
]:
    """Pack the nonzero band of square matrix W into solve_banded() layout.

    Returns (banded, l, u).  `l` is the number of sub-diagonals and `u` the
    number of super-diagonals out to the furthest one holding a nonzero
    entry.  `banded` has shape (u + 1 + l, n) and satisfies
    banded[u + i - j, j] == W[i, j] inside the band; slots that correspond to
    no element of W are left as NaN.
    """
    n, _ = W.shape

    def outermost(sign: int) -> int:
        """Largest k >= 1 whose diagonal at offset sign*k has a nonzero, else 0."""
        for k in range(n - 1, 0, -1):
            if np.count_nonzero(np.diagonal(W, offset=sign * k)):
                return k
        return 0

    # np.diagonal uses negative offsets below the main diagonal and positive
    # offsets above it.
    l = outermost(-1)
    u = outermost(+1)
    banded = np.full(shape=(u + 1 + l, n), fill_value=np.nan)

    for k in range(-l, u + 1):
        diag = np.diagonal(W, offset=k)
        if k >= 0:
            # Super-diagonal k occupies columns k..n-1 of row u - k.
            banded[u - k, k:] = diag
        else:
            # Sub-diagonal |k| occupies columns 0..n-|k|-1 of row u + |k|.
            banded[u - k, :n + k] = diag
    return banded, l, u


def bind_methods(W: np.ndarray) -> tuple[
    typing.Callable[[np.ndarray], np.ndarray], ...
]:
    """Return every implementation as a callable that takes only the matrix.

    Asserts that W is Hermitian, derives its banded form once, and binds the
    extra arguments of the structural and banded variants with partial().
    """
    assert scipy.linalg.ishermitian(W)
    band_storage, lower, upper = make_banded(W)

    # 'pos' is left out: the sample is not positive-definite.
    structural_variants = tuple(
        partial(sherman_morrison_structural, assume=kind)
        for kind in ('gen', 'sym', 'her')
    )
    banded_variant = partial(
        sherman_morrison_banded, l=lower, u=upper, a_banded=band_storage,
    )
    return (
        nested_inverse,
        sherman_morrison_op,
        sherman_morrison_mod,
        sherman_morrison_scipy,
        *structural_variants,
        banded_variant,
    )


def test() -> None:
    """Check on a fixed 9x9 sample that all bound methods agree.

    The first method is compared with the second (the reference) to an
    absolute tolerance of 1e-8, and every remaining method with the
    reference to 1e-16.
    """
    W = np.array([
        [0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 3, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
    ])

    nested, reference_method, *remaining = bind_methods(W)
    op_nested = nested(W)
    reference = reference_method(W)
    worst = np.abs(op_nested - reference).max()
    print(f'Error from nested to Sherman-Morrison: {worst:.1e}')
    assert np.allclose(op_nested, reference, rtol=0, atol=1e-8)

    for candidate in remaining:
        assert np.allclose(reference, candidate(W), rtol=0, atol=1e-16)


def benchmark() -> None:
    """Time every bound method except the first on a seeded 100x100 tridiagonal matrix.

    The matrix has a random integer main diagonal and one random
    off-diagonal mirrored above and below it, all drawn from {0, 1, 2}.
    The first method is only printed, not timed.  Each remaining method is
    run 10 times and its mean time per call printed in milliseconds.
    """
    rng = np.random.default_rng(seed=0)
    # Draw order is kept (off-diagonal first) so the seeded matrix is unchanged.
    off_diag = rng.integers(low=0, high=3, size=99)
    main_diag = rng.integers(low=0, high=3, size=100)
    w = np.diag(main_diag) + np.diag(off_diag, k=1) + np.diag(off_diag, k=-1)

    slowest, *timed = bind_methods(w)
    print(slowest)
    print('Too slow to care about')
    print()

    repeats = 10
    for method in timed:
        print(str(method).splitlines()[0])
        elapsed = timeit.timeit(lambda: method(w), number=repeats)
        print(f'{elapsed/repeats*1e3:.2f} ms')
        print()


if __name__ == '__main__':
    # Run the agreement checks first, then the timings.
    test()
    benchmark()
Error from nested to Sherman-Morrison: 3.7e-09

<function nested_inverse at 0x0000021549AA2FC0>
Too slow to care about
 
<function sherman_morrison_op at 0x0000021549E632E0>
155.91 ms
 
<function sherman_morrison_mod at 0x000002157B1D0360>
2.01 ms
 
<function sherman_morrison_scipy at 0x000002157B110E00>
0.44 ms
 
functools.partial(<function sherman_morrison_structural at 0x000002157B25A5C0>, assume='gen')
0.28 ms
 
functools.partial(<function sherman_morrison_structural at 0x000002157B25A5C0>, assume='sym')
0.56 ms
 
functools.partial(<function sherman_morrison_structural at 0x000002157B25A5C0>, assume='her')
0.39 ms
 
functools.partial(<function sherman_morrison_banded at 0x000002157B25A660>, l=1, u=1, a_banded=array([[nan,  2.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  1.,  2.,
0.17 ms

The benchmark uses a matrix that has the same value range, is larger, still banded-Hermitian and has lower L and U. Again, for your actual data you'll probably need to benchmark again to find which method is most appropriate.

import timeit
import typing
from functools import partial

import numpy as np
import scipy.linalg


def nested_inverse(A: np.ndarray) -> np.ndarray:
    """OP method 1: one full matrix inversion per index pair.

    For every pair (i, j) with i <= j, a copy of A gets its [i, j] and
    [j, i] entries zeroed, n*I minus that copy is inverted (n being the row
    count of A), and element [i, j] of the inverse is written to both
    [i, j] and [j, i] of the result.
    """
    size, _ = A.shape
    scaled_identity = size * np.eye(size)
    result = np.empty((size, size))

    # Row-major walk over the upper triangle, main diagonal included.
    for row, col in zip(*np.triu_indices(size)):
        punctured = A.copy()
        punctured[row, col] = 0
        punctured[col, row] = 0
        entry = np.linalg.inv(scaled_identity - punctured)[row, col]
        result[row, col] = entry
        result[col, row] = entry
    return result


def sherman_morrison_op(A: np.ndarray) -> np.ndarray:
    """OP method 2: a single inversion followed by a per-pair correction.

    inv(n*I - A) is computed once (n being the row count of A).  For each
    pair (i, j) with i <= j, the full matrix
    inv - b[i, j] * (column i of inv)(row j of inv) is formed and only its
    [i, j] element is kept, mirrored across the diagonal.
    """
    n_node, _ = A.shape
    base_inv = np.linalg.inv(n_node * np.eye(n_node) - A)
    scale = A / (1 + np.multiply(A, base_inv))

    # Extra length-1 axis so the slices below stay two-dimensional:
    # an (n, 1) column and a (1, n) row.
    stacked = base_inv[:, np.newaxis, :]
    out = np.empty((n_node, n_node))

    for i in range(n_node):
        column = stacked[:, :, i]
        for j in range(i, n_node):
            row = stacked[j, :, :]
            corrected = base_inv - scale[i, j] * (column @ row)
            out[i, j] = out[j, i] = corrected[i, j]
    return out


def sherman_morrison_mod(a: np.ndarray) -> np.ndarray:
    """Loop-free variant of the Sherman-Morrison computation using NumPy's inverse.

    With M = inv(n*I - a) and d = diag(M), returns
    M - a/(1 + a*M) * outer(d, d); the products and the quotient other than
    the outer product are elementwise.
    """
    order, _ = a.shape
    base_inv = np.linalg.inv(order * np.eye(order) - a)  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag_products = diag[:, np.newaxis] * diag[np.newaxis, :]  # (n, n)
    return base_inv - gain * diag_products


def sherman_morrison_scipy(a: np.ndarray) -> np.ndarray:
    """Loop-free variant of the Sherman-Morrison computation using scipy.linalg.inv.

    With M = inv(n*I - a) and d = diag(M), returns
    M - a/(1 + a*M) * outer(d, d).  The temporary n*I - a may be overwritten
    by SciPy and finiteness checking is disabled.
    """
    order, _ = a.shape
    shifted = order * np.eye(order) - a
    base_inv = scipy.linalg.inv(shifted, overwrite_a=True, check_finite=False)  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag_products = diag[:, np.newaxis] * diag[np.newaxis, :]  # (n, n)
    return base_inv - gain * diag_products


def sherman_morrison_structural(
    a: np.ndarray,
    assume: typing.Literal['gen', 'sym', 'her', 'pos'] = 'gen',
) -> np.ndarray:
    """Loop-free variant that obtains the inverse by solving against the identity.

    M = inv(n*I - a) is computed as scipy.linalg.solve(n*I - a, I), with
    `assume` forwarded as SciPy's `assume_a` structure hint.  Returns
    M - a/(1 + a*M) * outer(d, d) where d = diag(M).
    """
    order, _ = a.shape
    identity = np.eye(order)
    system = order * identity - a
    # Both operands are throwaway temporaries, so SciPy may overwrite them.
    base_inv = scipy.linalg.solve(
        system, identity,
        assume_a=assume, check_finite=False,
        overwrite_a=True, overwrite_b=True,
    )  # (n, n)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    return base_inv - gain * np.outer(diag, diag)


def sherman_morrison_banded(a: np.ndarray, a_banded: np.ndarray, l: int, u: int) -> np.ndarray:
    """Loop-free variant that obtains the inverse with a banded solver.

    Parameters:
        a: square matrix.
        a_banded: the same matrix in the diagonal-ordered layout expected by
            scipy.linalg.solve_banded, with `u` super-diagonal rows above the
            main-diagonal row and `l` sub-diagonal rows below it.  It is not
            modified.
        l: number of sub-diagonals stored in `a_banded`.
        u: number of super-diagonals stored in `a_banded`.

    Returns:
        M - a/(1 + a*M) * outer(d, d), where M = inv(n*I - a), d = diag(M)
        and n is the row count of `a`.
    """
    n, _ = a.shape
    ab = -a_banded  # negation allocates a fresh array, so the caller's data is untouched
    # Row u is the main diagonal.  Add n there rather than assigning it, so
    # that a nonzero diagonal of `a` is kept and ab really stores n*I - a.
    ab[u] += n
    inv = scipy.linalg.solve_banded(
        (l, u),
        ab=ab, overwrite_ab=True, check_finite=False,
        b=np.eye(n), overwrite_b=True,
    )  # n,n
    b = a/(1 + a*inv)  # n,n
    invdiag = np.diagonal(inv)  # n
    c = inv - b*np.outer(invdiag, invdiag)  # n,n
    return c


def make_banded(W: np.ndarray) -> tuple[
    np.ndarray,
    int, int,
]:
    """Pack the nonzero band of square matrix W into solve_banded() layout.

    Returns (banded, l, u).  `l` is the number of sub-diagonals and `u` the
    number of super-diagonals out to the furthest one holding a nonzero
    entry.  `banded` has shape (u + 1 + l, n) and satisfies
    banded[u + i - j, j] == W[i, j] inside the band; slots that correspond to
    no element of W are left as NaN.
    """
    n, _ = W.shape

    def outermost(sign: int) -> int:
        """Largest k >= 1 whose diagonal at offset sign*k has a nonzero, else 0."""
        for k in range(n - 1, 0, -1):
            if np.count_nonzero(np.diagonal(W, offset=sign * k)):
                return k
        return 0

    # np.diagonal uses negative offsets below the main diagonal and positive
    # offsets above it.
    l = outermost(-1)
    u = outermost(+1)
    banded = np.full(shape=(u + 1 + l, n), fill_value=np.nan)

    for k in range(-l, u + 1):
        diag = np.diagonal(W, offset=k)
        if k >= 0:
            # Super-diagonal k occupies columns k..n-1 of row u - k.
            banded[u - k, k:] = diag
        else:
            # Sub-diagonal |k| occupies columns 0..n-|k|-1 of row u + |k|.
            banded[u - k, :n + k] = diag
    return banded, l, u


def bind_methods(W: np.ndarray) -> tuple[
    typing.Callable[[np.ndarray], np.ndarray], ...
]:
    """Return every implementation as a callable that takes only the matrix.

    Asserts that W is Hermitian, derives its banded form once, and binds the
    extra arguments of the structural and banded variants with partial().
    """
    assert scipy.linalg.ishermitian(W)
    band_storage, lower, upper = make_banded(W)

    # 'pos' is left out: the sample is not positive-definite.
    structural_variants = tuple(
        partial(sherman_morrison_structural, assume=kind)
        for kind in ('gen', 'sym', 'her')
    )
    banded_variant = partial(
        sherman_morrison_banded, l=lower, u=upper, a_banded=band_storage,
    )
    return (
        nested_inverse,
        sherman_morrison_op,
        sherman_morrison_mod,
        sherman_morrison_scipy,
        *structural_variants,
        banded_variant,
    )


def test() -> None:
    """Check on a fixed 9x9 sample that all bound methods agree.

    The first method is compared with the second (the reference) to an
    absolute tolerance of 1e-8, and every remaining method with the
    reference to 1e-16.
    """
    W = np.array([
        [0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 3, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
    ])

    nested, reference_method, *remaining = bind_methods(W)
    op_nested = nested(W)
    reference = reference_method(W)
    worst = np.abs(op_nested - reference).max()
    print(f'Error from nested to Sherman-Morrison: {worst:.1e}')
    print()
    assert np.allclose(op_nested, reference, rtol=0, atol=1e-8)

    for candidate in remaining:
        assert np.allclose(reference, candidate(W), rtol=0, atol=1e-16)


def benchmark() -> None:
    """Time every bound method except the first on a seeded 100x100 tridiagonal matrix.

    The matrix has a random integer main diagonal and one random
    off-diagonal mirrored above and below it, all drawn from {0, 1, 2}.
    The first method is only printed, not timed.  Methods whose string form
    contains '_op' are run 5 times and all others 100 times; the mean time
    per call is printed in milliseconds.
    """
    rng = np.random.default_rng(seed=0)
    # Draw order is kept (off-diagonal first) so the seeded matrix is unchanged.
    off_diag = rng.integers(low=0, high=3, size=99)
    main_diag = rng.integers(low=0, high=3, size=100)
    w = np.diag(main_diag) + np.diag(off_diag, k=1) + np.diag(off_diag, k=-1)

    slowest, *timed = bind_methods(w)
    print(slowest)
    print('Too slow to care about')

    for method in timed:
        label = str(method)
        print(label.splitlines()[0])
        repeats = 5 if '_op' in label else 100
        elapsed = timeit.timeit(lambda: method(w), number=repeats)
        print(f'{elapsed/repeats*1e3:.2f} ms')


if __name__ == '__main__':
    # Run the agreement checks first, then the timings.
    test()
    benchmark()
Error from nested to Sherman-Morrison: 3.7e-09

<function nested_inverse at 0x000002A242822FC0>
Too slow to care about
<function sherman_morrison_op at 0x000002A242BDF2E0>
149.03 ms
<function sherman_morrison_mod at 0x000002A273EF0360>
1.36 ms
<function sherman_morrison_scipy at 0x000002A273E30E00>
0.49 ms
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='gen')
0.32 ms
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='sym')
0.61 ms
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='her')
0.38 ms
functools.partial(<function sherman_morrison_banded at 0x000002A273F76660>, l=1, u=1, a_banded=array([[nan,  2.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  1.,  2.,
0.17 ms
add scipy inv
Source Link
Reinderien
  • 71.2k
  • 5
  • 76
  • 257
import timeit
import typing
from functools import partial

import numpy as np
import scipy.linalg


def nested_inverse(A: np.ndarray) -> np.ndarray:
    """OP method 1: one full matrix inversion per index pair.

    For every pair (i, j) with i <= j, a copy of A gets its [i, j] and
    [j, i] entries zeroed, n*I minus that copy is inverted (n being the row
    count of A), and element [i, j] of the inverse is written to both
    [i, j] and [j, i] of the result.
    """
    size, _ = A.shape
    scaled_identity = size * np.eye(size)
    result = np.empty((size, size))

    # Row-major walk over the upper triangle, main diagonal included.
    for row, col in zip(*np.triu_indices(size)):
        punctured = A.copy()
        punctured[row, col] = 0
        punctured[col, row] = 0
        entry = np.linalg.inv(scaled_identity - punctured)[row, col]
        result[row, col] = entry
        result[col, row] = entry
    return result


def sherman_morrison_op(A: np.ndarray) -> np.ndarray:
    """OP method 2: a single inversion followed by a per-pair correction.

    inv(n*I - A) is computed once (n being the row count of A).  For each
    pair (i, j) with i <= j, the full matrix
    inv - b[i, j] * (column i of inv)(row j of inv) is formed and only its
    [i, j] element is kept, mirrored across the diagonal.
    """
    n_node, _ = A.shape
    base_inv = np.linalg.inv(n_node * np.eye(n_node) - A)
    scale = A / (1 + np.multiply(A, base_inv))

    # Extra length-1 axis so the slices below stay two-dimensional:
    # an (n, 1) column and a (1, n) row.
    stacked = base_inv[:, np.newaxis, :]
    out = np.empty((n_node, n_node))

    for i in range(n_node):
        column = stacked[:, :, i]
        for j in range(i, n_node):
            row = stacked[j, :, :]
            corrected = base_inv - scale[i, j] * (column @ row)
            out[i, j] = out[j, i] = corrected[i, j]
    return out


def sherman_morrison_mod(a: np.ndarray) -> np.ndarray:
    """Loop-free variant of the Sherman-Morrison computation using NumPy's inverse.

    With M = inv(n*I - a) and d = diag(M), returns
    M - a/(1 + a*M) * outer(d, d); the products and the quotient other than
    the outer product are elementwise.
    """
    order, _ = a.shape
    base_inv = np.linalg.inv(order * np.eye(order) - a)  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag_products = diag[:, np.newaxis] * diag[np.newaxis, :]  # (n, n)
    return base_inv - gain * diag_products


def sherman_morrison_scipy(a: np.ndarray) -> np.ndarray:
    """Loop-free variant of the Sherman-Morrison computation using scipy.linalg.inv.

    With M = inv(n*I - a) and d = diag(M), returns
    M - a/(1 + a*M) * outer(d, d).  The temporary n*I - a may be overwritten
    by SciPy and finiteness checking is disabled.
    """
    order, _ = a.shape
    shifted = order * np.eye(order) - a
    base_inv = scipy.linalg.inv(shifted, overwrite_a=True, check_finite=False)  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag_products = diag[:, np.newaxis] * diag[np.newaxis, :]  # (n, n)
    return base_inv - gain * diag_products


def sherman_morrison_structural(
    a: np.ndarray,
    assume: typing.Literal['gen', 'sym', 'her', 'pos'] = 'gen',
) -> np.ndarray:
    """Loop-free variant that obtains the inverse by solving against the identity.

    M = inv(n*I - a) is computed as scipy.linalg.solve(n*I - a, I), with
    `assume` forwarded as SciPy's `assume_a` structure hint.  Returns
    M - a/(1 + a*M) * outer(d, d) where d = diag(M).
    """
    order, _ = a.shape
    identity = np.eye(order)
    system = order * identity - a
    # Both operands are throwaway temporaries, so SciPy may overwrite them.
    base_inv = scipy.linalg.solve(
        system, identity,
        assume_a=assume, check_finite=False,
        overwrite_a=True, overwrite_b=True,
    )  # (n, n)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    return base_inv - gain * np.outer(diag, diag)


def sherman_morrison_banded(a: np.ndarray, a_banded: np.ndarray, l: int, u: int) -> np.ndarray:
    """Loop-free variant that obtains the inverse with a banded solver.

    Parameters:
        a: square matrix.
        a_banded: the same matrix in the diagonal-ordered layout expected by
            scipy.linalg.solve_banded, with `u` super-diagonal rows above the
            main-diagonal row and `l` sub-diagonal rows below it.  It is not
            modified.
        l: number of sub-diagonals stored in `a_banded`.
        u: number of super-diagonals stored in `a_banded`.

    Returns:
        M - a/(1 + a*M) * outer(d, d), where M = inv(n*I - a), d = diag(M)
        and n is the row count of `a`.
    """
    n, _ = a.shape
    ab = -a_banded  # negation allocates a fresh array, so the caller's data is untouched
    # Row u is the main diagonal.  Add n there rather than assigning it, so
    # that a nonzero diagonal of `a` is kept and ab really stores n*I - a.
    ab[u] += n
    inv = scipy.linalg.solve_banded(
        (l, u),
        ab=ab, overwrite_ab=True, check_finite=False,
        b=np.eye(n), overwrite_b=True,
    )  # n,n
    b = a/(1 + a*inv)  # n,n
    invdiag = np.diagonal(inv)  # n
    c = inv - b*np.outer(invdiag, invdiag)  # n,n
    return c


def make_banded(W: np.ndarray) -> tuple[
    np.ndarray,
    int, int,
]:
    """Pack the nonzero band of square matrix W into solve_banded() layout.

    Returns (banded, l, u).  `l` is the number of sub-diagonals and `u` the
    number of super-diagonals out to the furthest one holding a nonzero
    entry.  `banded` has shape (u + 1 + l, n) and satisfies
    banded[u + i - j, j] == W[i, j] inside the band; slots that correspond to
    no element of W are left as NaN.
    """
    n, _ = W.shape

    def outermost(sign: int) -> int:
        """Largest k >= 1 whose diagonal at offset sign*k has a nonzero, else 0."""
        for k in range(n - 1, 0, -1):
            if np.count_nonzero(np.diagonal(W, offset=sign * k)):
                return k
        return 0

    # np.diagonal uses negative offsets below the main diagonal and positive
    # offsets above it.
    l = outermost(-1)
    u = outermost(+1)
    banded = np.full(shape=(u + 1 + l, n), fill_value=np.nan)

    for k in range(-l, u + 1):
        diag = np.diagonal(W, offset=k)
        if k >= 0:
            # Super-diagonal k occupies columns k..n-1 of row u - k.
            banded[u - k, k:] = diag
        else:
            # Sub-diagonal |k| occupies columns 0..n-|k|-1 of row u + |k|.
            banded[u - k, :n + k] = diag
    return banded, l, u


def bind_methods(W: np.ndarray) -> tuple[
    typing.Callable[[np.ndarray], np.ndarray], ...
]:
    """Return every implementation as a callable that takes only the matrix.

    Asserts that W is Hermitian, derives its banded form once, and binds the
    extra arguments of the structural and banded variants with partial().
    """
    assert scipy.linalg.ishermitian(W)
    band_storage, lower, upper = make_banded(W)

    # 'pos' is left out: the sample is not positive-definite.
    structural_variants = tuple(
        partial(sherman_morrison_structural, assume=kind)
        for kind in ('gen', 'sym', 'her')
    )
    banded_variant = partial(
        sherman_morrison_banded, l=lower, u=upper, a_banded=band_storage,
    )
    return (
        nested_inverse,
        sherman_morrison_op,
        sherman_morrison_mod,
        sherman_morrison_scipy,
        *structural_variants,
        banded_variant,
    )


def test() -> None:
    """Check on a fixed 9x9 sample that all bound methods agree.

    The first method is compared with the second (the reference) to an
    absolute tolerance of 1e-8, and every remaining method with the
    reference to 1e-16.
    """
    W = np.array([
        [0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 3, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
    ])

    nested, reference_method, *remaining = bind_methods(W)
    op_nested = nested(W)
    reference = reference_method(W)
    worst = np.abs(op_nested - reference).max()
    print(f'Error from nested to Sherman-Morrison: {worst:.1e}')
    assert np.allclose(op_nested, reference, rtol=0, atol=1e-8)

    for candidate in remaining:
        assert np.allclose(reference, candidate(W), rtol=0, atol=1e-16)


def benchmark() -> None:
    """Time every bound method except the first on a seeded 100x100 tridiagonal matrix.

    The matrix has a random integer main diagonal and one random
    off-diagonal mirrored above and below it, all drawn from {0, 1, 2}.
    The first method is only printed, not timed.  Each remaining method is
    run 10 times and its mean time per call printed in milliseconds.
    """
    rng = np.random.default_rng(seed=0)
    # Draw order is kept (off-diagonal first) so the seeded matrix is unchanged.
    off_diag = rng.integers(low=0, high=3, size=99)
    main_diag = rng.integers(low=0, high=3, size=100)
    w = np.diag(main_diag) + np.diag(off_diag, k=1) + np.diag(off_diag, k=-1)

    slowest, *timed = bind_methods(w)
    print(slowest)
    print('Too slow to care about')
    print()

    repeats = 10
    for method in timed:
        print(str(method).splitlines()[0])
        elapsed = timeit.timeit(lambda: method(w), number=repeats)
        print(f'{elapsed/repeats*1e3:.2f} ms')
        print()


if __name__ == '__main__':
    # Run the agreement checks first, then the timings.
    test()
    benchmark()
Error from nested to Sherman-Morrison: 3.7e-09

<function nested_inverse at 0x0000021549AA2FC0>
Too slow to care about

<function sherman_morrison_op at 0x0000021549E632E0>
155.91 ms

<function sherman_morrison_mod at 0x000002157B1D0360>
2.01 ms

<function sherman_morrison_scipy at 0x000002157B110E00>
0.44 ms

functools.partial(<function sherman_morrison_structural at 0x000002157B25A5C0>, assume='gen')
0.28 ms

functools.partial(<function sherman_morrison_structural at 0x000002157B25A5C0>, assume='sym')
0.56 ms

functools.partial(<function sherman_morrison_structural at 0x000002157B25A5C0>, assume='her')
0.39 ms

functools.partial(<function sherman_morrison_banded at 0x000002157B25A660>, l=1, u=1, a_banded=array([[nan,  2.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  1.,  2.,
0.17 ms
import timeit
import typing
from functools import partial

import numpy as np
import scipy.linalg


def nested_inverse(A: np.ndarray) -> np.ndarray:
    """OP method 1: one full matrix inversion per index pair.

    For every pair (i, j) with i <= j, a copy of A gets its [i, j] and
    [j, i] entries zeroed, n*I minus that copy is inverted (n being the row
    count of A), and element [i, j] of the inverse is written to both
    [i, j] and [j, i] of the result.
    """
    size, _ = A.shape
    scaled_identity = size * np.eye(size)
    result = np.empty((size, size))

    # Row-major walk over the upper triangle, main diagonal included.
    for row, col in zip(*np.triu_indices(size)):
        punctured = A.copy()
        punctured[row, col] = 0
        punctured[col, row] = 0
        entry = np.linalg.inv(scaled_identity - punctured)[row, col]
        result[row, col] = entry
        result[col, row] = entry
    return result


def sherman_morrison_op(A: np.ndarray) -> np.ndarray:
    """OP method 2: a single inversion followed by a per-pair correction.

    inv(n*I - A) is computed once (n being the row count of A).  For each
    pair (i, j) with i <= j, the full matrix
    inv - b[i, j] * (column i of inv)(row j of inv) is formed and only its
    [i, j] element is kept, mirrored across the diagonal.
    """
    n_node, _ = A.shape
    base_inv = np.linalg.inv(n_node * np.eye(n_node) - A)
    scale = A / (1 + np.multiply(A, base_inv))

    # Extra length-1 axis so the slices below stay two-dimensional:
    # an (n, 1) column and a (1, n) row.
    stacked = base_inv[:, np.newaxis, :]
    out = np.empty((n_node, n_node))

    for i in range(n_node):
        column = stacked[:, :, i]
        for j in range(i, n_node):
            row = stacked[j, :, :]
            corrected = base_inv - scale[i, j] * (column @ row)
            out[i, j] = out[j, i] = corrected[i, j]
    return out


def sherman_morrison_mod(a: np.ndarray) -> np.ndarray:
    """Loop-free variant of the Sherman-Morrison computation using NumPy's inverse.

    With M = inv(n*I - a) and d = diag(M), returns
    M - a/(1 + a*M) * outer(d, d); the products and the quotient other than
    the outer product are elementwise.
    """
    order, _ = a.shape
    base_inv = np.linalg.inv(order * np.eye(order) - a)  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag_products = diag[:, np.newaxis] * diag[np.newaxis, :]  # (n, n)
    return base_inv - gain * diag_products


def sherman_morrison_structural(
    a: np.ndarray,
    assume: typing.Literal['gen', 'sym', 'her', 'pos'] = 'gen',
) -> np.ndarray:
    """Loop-free variant that obtains the inverse by solving against the identity.

    M = inv(n*I - a) is computed as scipy.linalg.solve(n*I - a, I), with
    `assume` forwarded as SciPy's `assume_a` structure hint.  Returns
    M - a/(1 + a*M) * outer(d, d) where d = diag(M).
    """
    order, _ = a.shape
    identity = np.eye(order)
    system = order * identity - a
    # Both operands are throwaway temporaries, so SciPy may overwrite them.
    base_inv = scipy.linalg.solve(
        system, identity,
        assume_a=assume, check_finite=False,
        overwrite_a=True, overwrite_b=True,
    )  # (n, n)
    gain = np.divide(a, 1 + np.multiply(a, base_inv))  # (n, n)
    diag = np.diagonal(base_inv)  # (n,)
    return base_inv - gain * np.outer(diag, diag)


def sherman_morrison_banded(a: np.ndarray, a_banded: np.ndarray, l: int, u: int) -> np.ndarray:
    """Loop-free variant that obtains the inverse with a banded solver.

    Parameters:
        a: square matrix.
        a_banded: the same matrix in the diagonal-ordered layout expected by
            scipy.linalg.solve_banded, with `u` super-diagonal rows above the
            main-diagonal row and `l` sub-diagonal rows below it.  It is not
            modified.
        l: number of sub-diagonals stored in `a_banded`.
        u: number of super-diagonals stored in `a_banded`.

    Returns:
        M - a/(1 + a*M) * outer(d, d), where M = inv(n*I - a), d = diag(M)
        and n is the row count of `a`.
    """
    n, _ = a.shape
    ab = -a_banded  # negation allocates a fresh array, so the caller's data is untouched
    # Row u is the main diagonal.  Add n there rather than assigning it, so
    # that a nonzero diagonal of `a` is kept and ab really stores n*I - a.
    ab[u] += n
    inv = scipy.linalg.solve_banded(
        (l, u),
        ab=ab, overwrite_ab=True, check_finite=False,
        b=np.eye(n), overwrite_b=True,
    )  # n,n
    b = a/(1 + a*inv)  # n,n
    invdiag = np.diagonal(inv)  # n
    c = inv - b*np.outer(invdiag, invdiag)  # n,n
    return c


def make_banded(W: np.ndarray) -> tuple[
    np.ndarray,
    int, int,
]:
    """Pack the nonzero band of square matrix W into solve_banded() layout.

    Returns (banded, l, u).  `l` is the number of sub-diagonals and `u` the
    number of super-diagonals out to the furthest one holding a nonzero
    entry.  `banded` has shape (u + 1 + l, n) and satisfies
    banded[u + i - j, j] == W[i, j] inside the band; slots that correspond to
    no element of W are left as NaN.
    """
    n, _ = W.shape

    def outermost(sign: int) -> int:
        """Largest k >= 1 whose diagonal at offset sign*k has a nonzero, else 0."""
        for k in range(n - 1, 0, -1):
            if np.count_nonzero(np.diagonal(W, offset=sign * k)):
                return k
        return 0

    # np.diagonal uses negative offsets below the main diagonal and positive
    # offsets above it.
    l = outermost(-1)
    u = outermost(+1)
    banded = np.full(shape=(u + 1 + l, n), fill_value=np.nan)

    for k in range(-l, u + 1):
        diag = np.diagonal(W, offset=k)
        if k >= 0:
            # Super-diagonal k occupies columns k..n-1 of row u - k.
            banded[u - k, k:] = diag
        else:
            # Sub-diagonal |k| occupies columns 0..n-|k|-1 of row u + |k|.
            banded[u - k, :n + k] = diag
    return banded, l, u


def bind_methods(W: np.ndarray) -> tuple[
    typing.Callable[[np.ndarray], np.ndarray], ...
]:
    """Return every implementation as a callable that takes only the matrix.

    Asserts that W is Hermitian, derives its banded form once, and binds the
    extra arguments of the structural and banded variants with partial().
    """
    assert scipy.linalg.ishermitian(W)
    band_storage, lower, upper = make_banded(W)

    # 'pos' is left out: the sample is not positive-definite.
    structural_variants = tuple(
        partial(sherman_morrison_structural, assume=kind)
        for kind in ('gen', 'sym', 'her')
    )
    banded_variant = partial(
        sherman_morrison_banded, l=lower, u=upper, a_banded=band_storage,
    )
    return (
        nested_inverse,
        sherman_morrison_op,
        sherman_morrison_mod,
        *structural_variants,
        banded_variant,
    )


def test() -> None:
    """Check on a fixed 9x9 sample that all bound methods agree.

    The first method is compared with the second (the reference) to an
    absolute tolerance of 1e-8, and every remaining method with the
    reference to 1e-16.
    """
    W = np.array([
        [0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 3, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
    ])

    nested, reference_method, *remaining = bind_methods(W)
    op_nested = nested(W)
    reference = reference_method(W)
    worst = np.abs(op_nested - reference).max()
    print(f'Error from nested to Sherman-Morrison: {worst:.1e}')
    assert np.allclose(op_nested, reference, rtol=0, atol=1e-8)

    for candidate in remaining:
        assert np.allclose(reference, candidate(W), rtol=0, atol=1e-16)


def benchmark() -> None:
    """Time every bound method except the first on a seeded 100x100 tridiagonal matrix.

    The matrix has a random integer main diagonal and one random
    off-diagonal mirrored above and below it, all drawn from {0, 1, 2}.
    The first method is only printed, not timed.  Each remaining method is
    run 10 times and its mean time per call printed in milliseconds.
    """
    rng = np.random.default_rng(seed=0)
    # Draw order is kept (off-diagonal first) so the seeded matrix is unchanged.
    off_diag = rng.integers(low=0, high=3, size=99)
    main_diag = rng.integers(low=0, high=3, size=100)
    w = np.diag(main_diag) + np.diag(off_diag, k=1) + np.diag(off_diag, k=-1)

    slowest, *timed = bind_methods(w)
    print(slowest)
    print('Too slow to care about')
    print()

    repeats = 10
    for method in timed:
        print(str(method).splitlines()[0])
        elapsed = timeit.timeit(lambda: method(w), number=repeats)
        print(f'{elapsed/repeats*1e3:.2f} ms')
        print()


if __name__ == '__main__':
    # Run the agreement checks first, then the timings.
    test()
    benchmark()
Error from nested to Sherman-Morrison: 3.7e-09

<function nested_inverse at 0x000001EE1BEA2FC0>
Too slow to care about

<function sherman_morrison_op at 0x000001EE1C25F240>
269.28 ms

<function sherman_morrison_mod at 0x000001EE4D5302C0>
2.60 ms

functools.partial(<function sherman_morrison_structural at 0x000001EE4D470D60>, assume='gen')
0.32 ms

functools.partial(<function sherman_morrison_structural at 0x000001EE4D470D60>, assume='sym')
0.42 ms

functools.partial(<function sherman_morrison_structural at 0x000001EE4D470D60>, assume='her')
0.37 ms

functools.partial(<function sherman_morrison_banded at 0x000001EE4D5B6520>, l=1, u=1, a_banded=array([[nan,  2.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  1.,  2.,
0.18 ms
import timeit
import typing
from functools import partial

import numpy as np
import scipy.linalg


def nested_inverse(A: np.ndarray) -> np.ndarray:
    """OP method 1: one full matrix inversion per index pair.

    For every pair (i, j) with i <= j, a copy of A gets its [i, j] and
    [j, i] entries zeroed, n*I minus that copy is inverted (n being the row
    count of A), and element [i, j] of the inverse is written to both
    [i, j] and [j, i] of the result.
    """
    size, _ = A.shape
    scaled_identity = size * np.eye(size)
    result = np.empty((size, size))

    # Row-major walk over the upper triangle, main diagonal included.
    for row, col in zip(*np.triu_indices(size)):
        punctured = A.copy()
        punctured[row, col] = 0
        punctured[col, row] = 0
        entry = np.linalg.inv(scaled_identity - punctured)[row, col]
        result[row, col] = entry
        result[col, row] = entry
    return result


def sherman_morrison_op(A: np.ndarray) -> np.ndarray:
    """OP method 2: a single inversion plus a looped rank-one correction.

    ``n*I - A`` is inverted once. For each pair (i, j) with i <= j, the outer
    product of column i and row j of that inverse, scaled by
    ``A[i, j] / (1 + A[i, j]*inv[i, j])``, is subtracted from the inverse,
    and the (i, j) entry of the difference is stored symmetrically.
    """
    size, _ = A.shape
    base_inv = np.linalg.inv(size * np.eye(size) - A)
    scale = np.divide(A, 1 + np.einsum('ij,ij->ij', A, base_inv))
    result = np.empty((size, size))

    # Row-major walk over the upper triangle, diagonal included.
    for i, j in zip(*np.triu_indices(size)):
        column_i = base_inv[:, i, np.newaxis]  # n,1
        row_j = base_inv[np.newaxis, j, :]  # 1,n
        # The full n,n correction is built deliberately, as in the OP's code.
        corrected = base_inv - scale[i, j] * np.matmul(column_i, row_j)
        result[i, j] = result[j, i] = corrected[i, j]
    return result


def sherman_morrison_mod(a: np.ndarray) -> np.ndarray:
    """Loop-free Sherman-Morrison update using ``np.linalg.inv``.

    Computes ``inv - a/(1 + a*inv) * outer(diag(inv), diag(inv))`` with
    ``inv = (n*I - a)^-1``; every product and quotient is element-wise.
    """
    n, _ = a.shape
    base_inv = np.linalg.inv(n*np.eye(n) - a)  # n,n
    diag = base_inv.diagonal()  # n
    scale = a/(1 + a*base_inv)  # n,n
    diag_products = diag[:, np.newaxis]*diag[np.newaxis, :]  # n,n
    return base_inv - scale*diag_products


def sherman_morrison_scipy(a: np.ndarray) -> np.ndarray:
    """Loop-free Sherman-Morrison update using ``scipy.linalg.inv``.

    Same arithmetic as the NumPy variant, but the inverse is taken by SciPy
    with finiteness checks disabled and permission to overwrite the
    temporary ``n*I - a``.
    """
    n, _ = a.shape
    base_inv = scipy.linalg.inv(
        n*np.eye(n) - a, overwrite_a=True, check_finite=False,
    )  # n,n
    diag = base_inv.diagonal()  # n
    scale = a/(1 + a*base_inv)  # n,n
    diag_products = diag[:, np.newaxis]*diag[np.newaxis, :]  # n,n
    return base_inv - scale*diag_products


def sherman_morrison_structural(
    a: np.ndarray,
    assume: typing.Literal['gen', 'sym', 'her', 'pos'] = 'gen',
) -> np.ndarray:
    """Loop-free Sherman-Morrison update with a structure-aware solve.

    The inverse of ``n*I - a`` is obtained by solving against the identity
    with ``scipy.linalg.solve``; ``assume`` is forwarded as ``assume_a``.
    """
    n, _ = a.shape
    # Both operands are fresh temporaries, so SciPy may overwrite them.
    base_inv = scipy.linalg.solve(
        n*np.eye(n) - a, np.eye(n),
        assume_a=assume, overwrite_a=True, overwrite_b=True,
        check_finite=False,
    )  # n,n
    diag = base_inv.diagonal()  # n
    scale = a/(1 + a*base_inv)  # n,n
    diag_products = diag[:, np.newaxis]*diag[np.newaxis, :]  # n,n
    return base_inv - scale*diag_products


def sherman_morrison_banded(a: np.ndarray, a_banded: np.ndarray, l: int, u: int) -> np.ndarray:
    """Loop-free Sherman-Morrison update with a banded solve.

    Parameters:
        a: the dense square matrix.
        a_banded: ``a`` in the diagonal-ordered layout taken by
            ``scipy.linalg.solve_banded``; row ``u`` holds the main diagonal.
        l: number of sub-diagonals stored in ``a_banded``.
        u: number of super-diagonals stored in ``a_banded``.

    Returns:
        ``inv - a/(1 + a*inv) * outer(diag(inv), diag(inv))`` with
        ``inv = (n*I - a)^-1``.
    """
    n, _ = a.shape
    # Negation allocates a new array, so the caller's a_banded is untouched.
    ab = -a_banded
    # Banded form of n*I - a: add n to the (negated) main diagonal. Assigning
    # n outright would discard a's own diagonal and disagree with the dense
    # variants whenever that diagonal is non-zero.
    ab[u] += n
    inv = scipy.linalg.solve_banded(
        (l, u),
        ab=ab, overwrite_ab=True, check_finite=False,
        b=np.eye(n), overwrite_b=True,
    )  # n,n
    b = a/(1 + a*inv)  # n,n
    invdiag = np.diagonal(inv)  # n
    c = inv - b*np.outer(invdiag, invdiag)  # n,n
    return c


def make_banded(W: np.ndarray) -> tuple[
    np.ndarray,
    int, int,
]:
    """Construct the diagonal-ordered array taken by solve_banded().

    Returns ``(banded, l, u)``. ``l`` is the offset of the outermost non-zero
    sub-diagonal of ``W`` and ``u`` that of the outermost non-zero
    super-diagonal; both are 0 when no such diagonal exists. ``banded`` has
    shape ``(u + 1 + l, n)`` with ``banded[u + i - j, j] == W[i, j]``;
    positions that correspond to no entry of ``W`` are NaN.
    """
    n, _ = W.shape

    def outermost(sign: int) -> int:
        # Largest offset k >= 1 whose diagonal has a non-zero entry, else 0.
        return next(
            (
                k for k in range(n - 1, 0, -1)
                if np.count_nonzero(np.diagonal(W, offset=sign*k))
            ),
            0,
        )

    # Positive np.diagonal offsets lie above the main diagonal, so they
    # determine u; negative offsets determine l.
    u = outermost(+1)
    l = outermost(-1)
    banded = np.full(shape=(u + 1 + l, n), fill_value=np.nan)

    for k in range(-l, u + 1):
        diag = np.diagonal(W, offset=k)
        if k >= 0:
            # Super-diagonal k covers columns k..n-1.
            banded[u - k, k:] = diag
        else:
            # Sub-diagonal |k| covers columns 0..n-|k|-1.
            banded[u - k, :n + k] = diag
    return banded, l, u


def bind_methods(W: np.ndarray) -> tuple[
    typing.Callable[[np.ndarray], np.ndarray], ...
]:
    """Return every implementation as a one-argument callable.

    ``W`` must be Hermitian. It is used here only to precompute the banded
    layout that is bound into the banded variant.
    """
    assert scipy.linalg.ishermitian(W)
    banded, l, u = make_banded(W)

    # 'pos' excluded; the sample is not positive-definite!
    structural = tuple(
        partial(sherman_morrison_structural, assume=kind)
        for kind in ('gen', 'sym', 'her')
    )
    return (
        nested_inverse,
        sherman_morrison_op,
        sherman_morrison_mod,
        sherman_morrison_scipy,
        *structural,
        partial(sherman_morrison_banded, l=l, u=u, a_banded=banded),
    )


def test() -> None:
    """Check that all bound methods agree on a small 9x9 sample matrix.

    The first two methods are compared with an absolute tolerance of 1e-8
    and their largest difference is printed. Every later method is compared
    against the second one with an absolute tolerance of 1e-16.
    """
    W = np.array([
        [0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 3, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
    ])

    nested_fn, reference_fn, *candidates = bind_methods(W)
    op_nested = nested_fn(W)
    reference = reference_fn(W)
    print(f'Error from nested to Sherman-Morrison: {np.abs(op_nested - reference).max():.1e}')
    assert np.allclose(op_nested, reference, rtol=0, atol=1e-8)

    for candidate in candidates:
        assert np.allclose(reference, candidate(W), rtol=0, atol=1e-16)


def benchmark() -> None:
    """Time every bound method on a seeded 100x100 symmetric tridiagonal matrix.

    The first method returned by ``bind_methods`` is only printed, not timed.
    Each remaining method is called 10 times and its mean wall time is
    printed in milliseconds.
    """
    rand = np.random.default_rng(seed=0)
    # Draw order matters for reproducibility: off-diagonal first, then diagonal.
    off_diag = rand.integers(low=0, high=3, size=99)
    main_diag = rand.integers(low=0, high=3, size=100)
    w = np.diag(main_diag) + np.diag(off_diag, k=1) + np.diag(off_diag, k=-1)

    slowest, *timed = bind_methods(w)
    print(slowest)
    print('Too slow to care about')
    print()

    n = 10
    for method in timed:
        # A partial's repr embeds a multi-line array; keep only its first line.
        print(str(method).splitlines()[0])
        t = timeit.timeit(lambda: method(w), number=n)
        print(f'{t/n*1e3:.2f} ms')
        print()


# Script entry point: run the agreement checks first, then the timings.
if __name__ == '__main__':
    test()
    benchmark()
Error from nested to Sherman-Morrison: 3.7e-09

<function nested_inverse at 0x0000021549AA2FC0>
Too slow to care about

<function sherman_morrison_op at 0x0000021549E632E0>
155.91 ms

<function sherman_morrison_mod at 0x000002157B1D0360>
2.01 ms

<function sherman_morrison_scipy at 0x000002157B110E00>
0.44 ms

functools.partial(<function sherman_morrison_structural at 0x000002157B25A5C0>, assume='gen')
0.28 ms

functools.partial(<function sherman_morrison_structural at 0x000002157B25A5C0>, assume='sym')
0.56 ms

functools.partial(<function sherman_morrison_structural at 0x000002157B25A5C0>, assume='her')
0.39 ms

functools.partial(<function sherman_morrison_banded at 0x000002157B25A660>, l=1, u=1, a_banded=array([[nan,  2.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  1.,  2.,
0.17 ms
Source Link
Reinderien
  • 71.2k
  • 5
  • 76
  • 257

I'll be disregarding torch, as this doesn't have a lot to do with machine learning specifically. Let's instead develop a demonstration in more generic numerical libraries, NumPy and SciPy. Most of the methods rhyme anyway, so this wouldn't be too difficult to port back to torch.

The outputs of your methods 1 and 2 agree to within ~1e-9 (but not to machine precision). I assume that this is not a problem, and I also assume that you've implemented Sherman-Morrison correctly.

Your second method needs standard vectorisation. After this process, no loops are necessary.

You write that

W is not necessarily a sparse matrix

This matters, and depending on the structure of the matrix you should adopt different methods for the inverse. In particular, your sample matrix is banded, symmetric and Hermitian. As I show below, if you know which of these hold true for any given call, you can make the inverse more efficient. However, in terms of time consumption, vectorisation is far and away more important than improving the inverse.

For the banded case, there is setup needed (for which I show a very inefficient method); if your matrix stays banded and L and U are low, again the inverse speeds up.

import timeit
import typing
from functools import partial

import numpy as np
import scipy.linalg


def nested_inverse(A: np.ndarray) -> np.ndarray:
    """OP method 1: one full matrix inversion per upper-triangle entry.

    For every pair (i, j) with i <= j, a copy of ``A`` has its (i, j) and
    (j, i) entries zeroed, ``n*I - copy`` is inverted, and the (i, j) entry
    of that inverse is stored symmetrically in the result.
    """
    size, _ = A.shape
    scaled_eye = size * np.eye(size)
    result = np.empty((size, size))

    # Row-major walk over the upper triangle, diagonal included.
    for i, j in zip(*np.triu_indices(size)):
        masked = A.copy()
        masked[i, j] = 0
        masked[j, i] = 0
        entry = np.linalg.inv(scaled_eye - masked)[i, j]
        result[i, j] = entry
        result[j, i] = entry
    return result


def sherman_morrison_op(A: np.ndarray) -> np.ndarray:
    """OP method 2: a single inversion plus a looped rank-one correction.

    ``n*I - A`` is inverted once. For each pair (i, j) with i <= j, the outer
    product of column i and row j of that inverse, scaled by
    ``A[i, j] / (1 + A[i, j]*inv[i, j])``, is subtracted from the inverse,
    and the (i, j) entry of the difference is stored symmetrically.
    """
    size, _ = A.shape
    base_inv = np.linalg.inv(size * np.eye(size) - A)
    scale = np.divide(A, 1 + np.einsum('ij,ij->ij', A, base_inv))
    result = np.empty((size, size))

    # Row-major walk over the upper triangle, diagonal included.
    for i, j in zip(*np.triu_indices(size)):
        column_i = base_inv[:, i, np.newaxis]  # n,1
        row_j = base_inv[np.newaxis, j, :]  # 1,n
        # The full n,n correction is built deliberately, as in the OP's code.
        corrected = base_inv - scale[i, j] * np.matmul(column_i, row_j)
        result[i, j] = result[j, i] = corrected[i, j]
    return result


def sherman_morrison_mod(a: np.ndarray) -> np.ndarray:
    """Loop-free Sherman-Morrison update using ``np.linalg.inv``.

    Computes ``inv - a/(1 + a*inv) * outer(diag(inv), diag(inv))`` with
    ``inv = (n*I - a)^-1``; every product and quotient is element-wise.
    """
    n, _ = a.shape
    base_inv = np.linalg.inv(n*np.eye(n) - a)  # n,n
    diag = base_inv.diagonal()  # n
    scale = a/(1 + a*base_inv)  # n,n
    diag_products = diag[:, np.newaxis]*diag[np.newaxis, :]  # n,n
    return base_inv - scale*diag_products


def sherman_morrison_structural(
    a: np.ndarray,
    assume: typing.Literal['gen', 'sym', 'her', 'pos'] = 'gen',
) -> np.ndarray:
    """Loop-free Sherman-Morrison update with a structure-aware solve.

    The inverse of ``n*I - a`` is obtained by solving against the identity
    with ``scipy.linalg.solve``; ``assume`` is forwarded as ``assume_a``.
    """
    n, _ = a.shape
    # Both operands are fresh temporaries, so SciPy may overwrite them.
    base_inv = scipy.linalg.solve(
        n*np.eye(n) - a, np.eye(n),
        assume_a=assume, overwrite_a=True, overwrite_b=True,
        check_finite=False,
    )  # n,n
    diag = base_inv.diagonal()  # n
    scale = a/(1 + a*base_inv)  # n,n
    diag_products = diag[:, np.newaxis]*diag[np.newaxis, :]  # n,n
    return base_inv - scale*diag_products


def sherman_morrison_banded(a: np.ndarray, a_banded: np.ndarray, l: int, u: int) -> np.ndarray:
    """Loop-free Sherman-Morrison update with a banded solve.

    Parameters:
        a: the dense square matrix.
        a_banded: ``a`` in the diagonal-ordered layout taken by
            ``scipy.linalg.solve_banded``; row ``u`` holds the main diagonal.
        l: number of sub-diagonals stored in ``a_banded``.
        u: number of super-diagonals stored in ``a_banded``.

    Returns:
        ``inv - a/(1 + a*inv) * outer(diag(inv), diag(inv))`` with
        ``inv = (n*I - a)^-1``.
    """
    n, _ = a.shape
    # Negation allocates a new array, so the caller's a_banded is untouched.
    ab = -a_banded
    # Banded form of n*I - a: add n to the (negated) main diagonal. Assigning
    # n outright would discard a's own diagonal and disagree with the dense
    # variants whenever that diagonal is non-zero.
    ab[u] += n
    inv = scipy.linalg.solve_banded(
        (l, u),
        ab=ab, overwrite_ab=True, check_finite=False,
        b=np.eye(n), overwrite_b=True,
    )  # n,n
    b = a/(1 + a*inv)  # n,n
    invdiag = np.diagonal(inv)  # n
    c = inv - b*np.outer(invdiag, invdiag)  # n,n
    return c


def make_banded(W: np.ndarray) -> tuple[
    np.ndarray,
    int, int,
]:
    """Construct the diagonal-ordered array taken by solve_banded().

    Returns ``(banded, l, u)``. ``l`` is the offset of the outermost non-zero
    sub-diagonal of ``W`` and ``u`` that of the outermost non-zero
    super-diagonal; both are 0 when no such diagonal exists. ``banded`` has
    shape ``(u + 1 + l, n)`` with ``banded[u + i - j, j] == W[i, j]``;
    positions that correspond to no entry of ``W`` are NaN.
    """
    n, _ = W.shape

    def outermost(sign: int) -> int:
        # Largest offset k >= 1 whose diagonal has a non-zero entry, else 0.
        return next(
            (
                k for k in range(n - 1, 0, -1)
                if np.count_nonzero(np.diagonal(W, offset=sign*k))
            ),
            0,
        )

    # Positive np.diagonal offsets lie above the main diagonal, so they
    # determine u; negative offsets determine l.
    u = outermost(+1)
    l = outermost(-1)
    banded = np.full(shape=(u + 1 + l, n), fill_value=np.nan)

    for k in range(-l, u + 1):
        diag = np.diagonal(W, offset=k)
        if k >= 0:
            # Super-diagonal k covers columns k..n-1.
            banded[u - k, k:] = diag
        else:
            # Sub-diagonal |k| covers columns 0..n-|k|-1.
            banded[u - k, :n + k] = diag
    return banded, l, u


def bind_methods(W: np.ndarray) -> tuple[
    typing.Callable[[np.ndarray], np.ndarray], ...
]:
    """Return every implementation as a one-argument callable.

    ``W`` must be Hermitian. It is used here only to precompute the banded
    layout that is bound into the banded variant.
    """
    assert scipy.linalg.ishermitian(W)
    banded, l, u = make_banded(W)

    # 'pos' excluded; the sample is not positive-definite!
    structural = tuple(
        partial(sherman_morrison_structural, assume=kind)
        for kind in ('gen', 'sym', 'her')
    )
    return (
        nested_inverse,
        sherman_morrison_op,
        sherman_morrison_mod,
        *structural,
        partial(sherman_morrison_banded, l=l, u=u, a_banded=banded),
    )


def test() -> None:
    """Check that all bound methods agree on a small 9x9 sample matrix.

    The first two methods are compared with an absolute tolerance of 1e-8
    and their largest difference is printed. Every later method is compared
    against the second one with an absolute tolerance of 1e-16.
    """
    W = np.array([
        [0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 3, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
    ])

    nested_fn, reference_fn, *candidates = bind_methods(W)
    op_nested = nested_fn(W)
    reference = reference_fn(W)
    print(f'Error from nested to Sherman-Morrison: {np.abs(op_nested - reference).max():.1e}')
    assert np.allclose(op_nested, reference, rtol=0, atol=1e-8)

    for candidate in candidates:
        assert np.allclose(reference, candidate(W), rtol=0, atol=1e-16)


def benchmark() -> None:
    """Time every bound method on a seeded 100x100 symmetric tridiagonal matrix.

    The first method returned by ``bind_methods`` is only printed, not timed.
    Each remaining method is called 10 times and its mean wall time is
    printed in milliseconds.
    """
    rand = np.random.default_rng(seed=0)
    # Draw order matters for reproducibility: off-diagonal first, then diagonal.
    off_diag = rand.integers(low=0, high=3, size=99)
    main_diag = rand.integers(low=0, high=3, size=100)
    w = np.diag(main_diag) + np.diag(off_diag, k=1) + np.diag(off_diag, k=-1)

    slowest, *timed = bind_methods(w)
    print(slowest)
    print('Too slow to care about')
    print()

    n = 10
    for method in timed:
        # A partial's repr embeds a multi-line array; keep only its first line.
        print(str(method).splitlines()[0])
        t = timeit.timeit(lambda: method(w), number=n)
        print(f'{t/n*1e3:.2f} ms')
        print()


# Script entry point: run the agreement checks first, then the timings.
if __name__ == '__main__':
    test()
    benchmark()
Error from nested to Sherman-Morrison: 3.7e-09

<function nested_inverse at 0x000001EE1BEA2FC0>
Too slow to care about

<function sherman_morrison_op at 0x000001EE1C25F240>
269.28 ms

<function sherman_morrison_mod at 0x000001EE4D5302C0>
2.60 ms

functools.partial(<function sherman_morrison_structural at 0x000001EE4D470D60>, assume='gen')
0.32 ms

functools.partial(<function sherman_morrison_structural at 0x000001EE4D470D60>, assume='sym')
0.42 ms

functools.partial(<function sherman_morrison_structural at 0x000001EE4D470D60>, assume='her')
0.37 ms

functools.partial(<function sherman_morrison_banded at 0x000001EE4D5B6520>, l=1, u=1, a_banded=array([[nan,  2.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  1.,  2.,
0.18 ms