
I have the following implementation that takes in a symmetric matrix W and returns a matrix C:

import torch

W = torch.tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0],
                  [1, 0, 1, 0, 0, 1, 0, 0, 0],
                  [0, 1, 0, 3, 0, 0, 0, 0, 0],
                  [0, 0, 3, 0, 1, 0, 0, 0, 0],
                  [0, 0, 0, 1, 0, 1, 1, 0, 0],
                  [0, 1, 0, 0, 1, 0, 0, 0, 0],
                  [0, 0, 0, 0, 1, 0, 0, 1, 0],
                  [0, 0, 0, 0, 0, 0, 1, 0, 1],
                  [0, 0, 0, 0, 0, 0, 0, 1, 0]])

A = W.clone()
n_node, _ = A.shape
nI = n_node * torch.eye(n_node)

# -- method 1
C = torch.empty(n_node, n_node)

for i in range(n_node):
    for j in range(i, n_node):
        B = A.clone()
        B[i, j] = B[j, i] = 0  # zero out the symmetric pair (i, j), (j, i)
        C[i, j] = C[j, i] = torch.inverse(nI - B)[i, j]

As you can see, at each iteration I zero out one symmetric pair of entries of B and compute the inverse of nI - B. I'm trying to use the Sherman-Morrison formula to improve the performance of the code. Here is my implementation:
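
For reference, the rank-one form of the Sherman-Morrison formula is

$$(M + uv^\top)^{-1} = M^{-1} - \frac{M^{-1} u v^\top M^{-1}}{1 + v^\top M^{-1} u},$$

which, with $M = nI - A$, $u = w_{ij} e_i$ and $v = e_j$ (zeroing $B_{ij}$ turns $nI - A$ into $(nI - A) + w_{ij} e_i e_j^\top$), gives the per-element update used in the loop below.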

# -- method 2
c = torch.empty(n_node, n_node)

inv_nI_A = torch.inverse(nI - A)
b = torch.div(A, 1 + torch.einsum('ij,ij->ij', A, inv_nI_A))
inv_nI_A_ = inv_nI_A.unsqueeze(1)

for i in range(n_node):
    for j in range(i, n_node):
        # only element [i, j] of the updated inverse is actually needed
        c[i, j] = c[j, i] = (
            inv_nI_A - b[i, j] * torch.matmul(inv_nI_A_[:, :, i], inv_nI_A_[j, :, :])
        )[i, j]

I was wondering whether further enhancements can be made to my implementation. Thanks!

PS: W is not necessarily a sparse matrix.

  • Based on the Sherman-Morrison formula, it seems that further mathematical simplifications can be done to throw away the for-loops completely (and replace them with implicit vectorized PyTorch operations). BTW, your 'method 1' code does not incorporate all the improvements from my previous review. E.g., copying the entire matrix using A.clone() in each loop might be more readable, but it is really inefficient when only two elements need to be changed. – Commented Nov 19, 2020 at 19:37
  • @GZ0 1. It isn't immediately clear to me how to get rid of the final loop. 2. Thanks for bringing the A.clone() part to my attention. 3. I forgot to mark your answer as accepted; thanks for the review again! – Commented Nov 20, 2020 at 17:06
  • Inside each loop only one element needs to be calculated. Based on the Sherman-Morrison formula, it is possible to do just that rather than computing the entire matrix inverse and then accessing the element. – Commented Nov 21, 2020 at 21:37

1 Answer


I'll be disregarding torch, as this doesn't have a lot to do with machine learning specifically. Let's instead develop a demonstration in more generic numerical libraries, NumPy and SciPy. Most of the methods rhyme anyway, so this wouldn't be too difficult to port back to torch (a sketch of such a port follows the benchmark output below).

The outputs of your methods 1 and 2 agree to within ~1e-9 (but not machine precision). I assume that this is not a problem, and I also assume that you've implemented Sherman-Morrison correctly.

Your second method needs standard vectorisation. The key observation is that only element [i, j] of each rank-one update is ever read, and that element equals inv[i, j] - b[i, j] * inv[i, i] * inv[j, j]; the double loop therefore collapses to an outer product of the diagonal of the inverse. After this process, no loops are necessary.

You write that

W is not necessarily a sparse matrix

This matters, and depending on the structure of the matrix you should adopt different methods for the inverse. In particular, your sample matrix is banded, symmetric and Hermitian. As I show below, if you know which of these holds true for any given call, you can make the inverse more efficient. However, in terms of time consumption vectorisation is far and away more important than improving the inverse.

For the banded case, there is setup needed (for which I show a very inefficient method); if your matrix stays banded and L and U are low, again the inverse speeds up.

The benchmark uses a matrix that has the same value range, is larger, is still banded and Hermitian, and has lower bandwidths l and u. Again, for your actual data you'll probably need to benchmark again to find which method is most appropriate.

import timeit
import typing
from functools import partial

import numpy as np
import scipy.linalg


def nested_inverse(A: np.ndarray) -> np.ndarray:
    """OP method 1"""
    n_node, _ = A.shape
    nI = n_node * np.eye(n_node)
    C = np.empty((n_node, n_node))

    for i in range(n_node):
        for j in range(i, n_node):
            B = A.copy()
            B[i, j] = B[j, i] = 0
            C[i, j] = C[j, i] = np.linalg.inv(nI - B)[i, j]
    return C


def sherman_morrison_op(A: np.ndarray) -> np.ndarray:
    """OP method 2"""
    n_node, _ = A.shape
    nI = n_node * np.eye(n_node)
    c = np.empty((n_node, n_node))

    inv_nI_A = np.linalg.inv(nI - A)
    b = np.divide(A, 1 + np.einsum('ij,ij->ij', A, inv_nI_A))
    inv_nI_A_ = inv_nI_A[:, np.newaxis, ...]

    for i in range(n_node):
        for j in range(i, n_node):
            c[i, j] = c[j, i] = (
                inv_nI_A - b[i, j] * np.matmul(inv_nI_A_[:, :, i], inv_nI_A_[j, :, :])
            )[i, j]
    return c


def sherman_morrison_mod(a: np.ndarray) -> np.ndarray:
    """Method 2, fully vectorised: the double loop collapses to an outer
    product of the diagonal of the inverse."""
    n, _ = a.shape
    nI = n*np.eye(n)
    inv_nI_A = np.linalg.inv(nI - a)  # n,n
    invdiag = np.diagonal(inv_nI_A)  # n
    b = a/(1 + a*inv_nI_A)  # n,n
    c = inv_nI_A - b*np.outer(invdiag, invdiag)  # n,n
    return c


def sherman_morrison_scipy(a: np.ndarray) -> np.ndarray:
    n, _ = a.shape
    nI = n*np.eye(n)
    inv_nI_A = scipy.linalg.inv(
        a=nI - a, overwrite_a=True, check_finite=False,
    )  # n,n
    invdiag = np.diagonal(inv_nI_A)  # n
    b = a/(1 + a*inv_nI_A)  # n,n
    c = inv_nI_A - b*np.outer(invdiag, invdiag)  # n,n
    return c


def sherman_morrison_structural(
    a: np.ndarray,
    assume: typing.Literal['gen', 'sym', 'her', 'pos'] = 'gen',
) -> np.ndarray:
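    """Compute the inverse by solving against the identity with a
    structure-specific LAPACK driver selected via assume_a:
    'gen' -> LU (*gesv), 'sym' -> symmetric indefinite (*sysv),
    'her' -> Hermitian (*hesv), 'pos' -> Cholesky (*posv).
    """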
    n, _ = a.shape
    eye = np.eye(n)
    inv = scipy.linalg.solve(
        a=n*eye - a, overwrite_a=True, assume_a=assume,
        b=eye, overwrite_b=True, check_finite=False,
    )  # n,n
    b = a/(1 + a*inv)  # n,n
    invdiag = np.diagonal(inv)  # n
    c = inv - b*np.outer(invdiag, invdiag)  # n,n
    return c


def sherman_morrison_banded(a: np.ndarray, a_banded: np.ndarray, l: int, u: int) -> np.ndarray:
    n, _ = a.shape
    ab = -a_banded
    ab[u] = np.full(shape=n, fill_value=n)
    inv = scipy.linalg.solve_banded(
        (l, u),
        ab=ab, overwrite_ab=True, check_finite=False,
        b=np.eye(n), overwrite_b=True,
    )  # n,n
    b = a/(1 + a*inv)  # n,n
    invdiag = np.diagonal(inv)  # n
    c = inv - b*np.outer(invdiag, invdiag)  # n,n
    return c


def make_banded(W: np.ndarray) -> tuple[
    np.ndarray,
    int, int,
]:
    """Dumb way to construct an array for solve_banded()"""
    n, _ = W.shape
    for l in range(n-1, -1, -1):
        if np.count_nonzero(np.diagonal(W, offset=l)):
            break
    for u in range(n-1, -1, -1):
        if np.count_nonzero(np.diagonal(W, offset=-u)):
            break
    banded = np.full(shape=(u + 1 + l, n), fill_value=np.nan)

    for band_i in range(u+1):
        for band_j in range(u - band_i, n):
            sq_i = band_i + band_j - u
            sq_j = band_j
            banded[band_i, band_j] = W[sq_i, sq_j]
    for band_i in range(u, u+l+1):
        for band_j in range(n - (band_i - u)):
            sq_i = band_i + band_j - u
            sq_j = band_j
            banded[band_i, band_j] = W[sq_i, sq_j]
    return banded, l, u
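

# For reference: solve_banded() expects the "matrix diagonal ordered form",
# ab[u + i - j, j] == a[i, j]. For the tridiagonal benchmark matrix below
# (l = u = 1), row 0 of ab holds the superdiagonal (its first entry is
# unused, hence the nan padding), row 1 the main diagonal, and row 2 the
# subdiagonal.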


def bind_methods(W: np.ndarray) -> tuple[
    typing.Callable[[np.ndarray], np.ndarray], ...
]:
    assert scipy.linalg.ishermitian(W)
    banded, l, u = make_banded(W)
    return (
        nested_inverse,
        sherman_morrison_op,
        sherman_morrison_mod,
        sherman_morrison_scipy,

        partial(sherman_morrison_structural, assume='gen'),
        partial(sherman_morrison_structural, assume='sym'),
        partial(sherman_morrison_structural, assume='her'),
        # 'pos' excluded; the sample is not positive-definite!

        partial(sherman_morrison_banded, l=l, u=u, a_banded=banded),
    )


def test() -> None:
    W = np.array([
        [0, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 3, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 1, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
    ])

    methods = iter(bind_methods(W))
    op_nested = next(methods)(W)
    reference = next(methods)(W)
    print(f'Error from nested to Sherman-Morrison: {np.abs(op_nested - reference).max():.1e}')
    print()
    assert np.allclose(op_nested, reference, rtol=0, atol=1e-8)

    for method in methods:
        c = method(W)
        assert np.allclose(reference, c, rtol=0, atol=1e-16)


def benchmark() -> None:
    rand = np.random.default_rng(seed=0)
    off_diag = rand.integers(low=0, high=3, size=99)
    w = (
        np.diag(rand.integers(low=0, high=3, size=100))
        + np.diag(off_diag, k=1)
        + np.diag(off_diag, k=-1)
    )
    methods = iter(bind_methods(w))

    print(next(methods))
    print('Too slow to care about')

    for method in methods:
        print(str(method).splitlines()[0])
        n = 5 if '_op' in str(method) else 100
        def run():
            return method(w)
        t = timeit.timeit(run, number=n)
        print(f'{t/n*1e3:.2f} ms')


if __name__ == '__main__':
    test()
    benchmark()
Output:

Error from nested to Sherman-Morrison: 3.7e-09

<function nested_inverse at 0x000002A242822FC0>
Too slow to care about
<function sherman_morrison_op at 0x000002A242BDF2E0>
149.03 ms
<function sherman_morrison_mod at 0x000002A273EF0360>
1.36 ms
<function sherman_morrison_scipy at 0x000002A273E30E00>
0.49 ms
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='gen')
0.32 ms
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='sym')
0.61 ms
functools.partial(<function sherman_morrison_structural at 0x000002A273F765C0>, assume='her')
0.38 ms
functools.partial(<function sherman_morrison_banded at 0x000002A273F76660>, l=1, u=1, a_banded=array([[nan,  2.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  1.,  2.,
0.17 ms
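
Since the question was originally in torch, here is a rough sketch of how the vectorised sherman_morrison_mod might be ported back. It's untested and unbenchmarked, and assumes a floating-point input tensor (convert an integer W with W.double() first):

import torch


def sherman_morrison_torch(a: torch.Tensor) -> torch.Tensor:
    n = a.shape[0]
    nI = n * torch.eye(n, dtype=a.dtype)
    inv_nI_A = torch.linalg.inv(nI - a)  # n,n
    invdiag = torch.diagonal(inv_nI_A)  # n
    b = a / (1 + a * inv_nI_A)  # n,n
    # only element [i, j] of each rank-one update is needed,
    # and it equals inv[i, i] * inv[j, j]
    return inv_nI_A - b * torch.outer(invdiag, invdiag)  # n,n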