A few hints for improving efficiency:
- When `W[i, j] == 0`, `B` equals `W`, so `tmp` remains unchanged. In this case, those `C` values can be computed once outside the loop instead of inside it. `torch.nonzero` / `torch.Tensor.nonzero` can be used to obtain all non-zero indices of a tensor (see the quick illustration after this list).
- Since `W` is symmetric, `C` is also symmetric, so only half of its values need to be computed.
- All repeated computation can be moved out of the loop.
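To make the index format concrete, here is a quick illustration of what `torch.nonzero` returns for a 2-D tensor; the small matrix below is made up purely for the example:

```python
import torch

# Hypothetical small symmetric matrix, just to show the index format.
W = torch.tensor([[0., 2., 0.],
                  [2., 0., 1.],
                  [0., 1., 0.]])

idx = W.nonzero()           # shape (num_nonzero, 2); same as torch.nonzero(W)
print(idx)
# tensor([[0, 1],
#         [1, 0],
#         [1, 2],
#         [2, 1]])

for i, j in idx:            # each row unpacks into a pair of 0-dim index tensors
    print(i.item(), j.item())
```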
Improved code:
```python
import torch

n = len(W)
nIW = n * torch.eye(n) - W          # n*I - W
C = nIW.inverse()                   # covers every (i, j) with W[i, j] == 0 in one shot

# Only visit non-zero entries, and only the upper triangle since W is symmetric.
for i, j in W.nonzero():
    if i < j:
        v = nIW[i, j].item()
        nIW[i, j] = nIW[j, i] = 0                # temporarily zero the (i, j) entry
        C[j, i] = C[i, j] = nIW.inverse()[i, j]  # recompute only this entry of C
        nIW[i, j] = nIW[j, i] = v                # restore the original value
```
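For completeness, a minimal usage sketch of the snippet above, wrapped in a helper function (`compute_C` is my own name, and the example `W` is made up); it assumes `W` is a square, symmetric float tensor:

```python
import torch

def compute_C(W):
    """Same loop as above, assuming W is square, symmetric, and float."""
    n = len(W)
    nIW = n * torch.eye(n) - W
    C = nIW.inverse()
    for i, j in W.nonzero():
        if i < j:
            v = nIW[i, j].item()
            nIW[i, j] = nIW[j, i] = 0
            C[j, i] = C[i, j] = nIW.inverse()[i, j]
            nIW[i, j] = nIW[j, i] = v
    return C

# Hypothetical symmetric input, purely to exercise the code.
W = torch.tensor([[0., 2., 0.],
                  [2., 0., 1.],
                  [0., 1., 0.]])
print(compute_C(W))
```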
As for backpropagation, I have no idea how it can be done through individual elements of matrix inverses.