Skip to main content
7 of 7
added 9 characters in body
GZ0
  • 2.4k
  • 8
  • 19

A few hints for improving efficiency:

  • When W[i, j] == 0, B equals to W so tmp remains unchanged. In this case, the computation of C values can be done only once outside the loop instead of inside the loop.

  • torch.nonzero / torch.Tensor.nonzero can be used to obtain all indices of non-zero values in a tensor.

  • Since W is symmetric, C is also symmetric and only half of its values need to be computed.

  • All repeated computation can be moved out of the loop to improve effciency.

Improved code:

n = len(W)
nIW = n * torch.eye(n) - W
nIB = nIW.clone()
C = nIB.inverse()

for i, j in W.nonzero():
    if i < j:
        nIB[i, j] = nIB[j, i] = 0
        C[j, i] = C[i, j] = nIB.inverse()[i, j]
        nIB[i, j] = nIB[j, i] = nIW[i, j]

Further performance improvement may be achieved according to this.

As for backpropagation, I have no idea how it can be done on elements of matrix inverses.

GZ0
  • 2.4k
  • 8
  • 19