Skip to main content
1 of 7
GZ0
  • 2.4k
  • 8
  • 19

A few hints for improving efficiency:

  • When W[i, j] == 0, B equals to W so tmp remains unchanged. In this case, the computation of C values can be done only once outside the loop instead of inside the loop.

  • torch.nonzero can be used to obtain all non-zero indices.

  • Since W is symmetric, C is also symmetric and only half of its values needs to be computed.

  • All repeated computation can be moved out of the loop to save effciency.

Improved code:

n = len(W)
nIW = n * torch.eye(n) - W
C = torch.inverse(nIW)

for i, j in torch.nonzero(W):
    if i < j:
       v = nIW[i, j]
       nIW[i, j] = v + W[i, j]
       C[i, j] = torch.inverse(nIW)[i, j]
       nIW[i, j] = v
    else:
       C[i, j] = C[j, i]

As for backpropagation, I have no idea how it can be done on elements of matrix inverses.

GZ0
  • 2.4k
  • 8
  • 19