PyTorch Vectorized Implementation for Thresholding and Computing Jaccard Index

I have been trying to optimize a code snippet that finds the optimal threshold value for a 256 * 256 probability map, i.e. the threshold that gives the highest Jaccard index against the ground truth mask. However, even though I have vectorized the implementation as much as I can, the code still takes around 330 seconds on a GPU. The idea of the snippet is to loop through every probability in the probability map and use each one as a threshold value to binarize the predicted probability map.
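To make the idea concrete, here is a minimal CPU sketch of the exhaustive search on a single toy 2 x 2 map. The values and the helper name `best_threshold` are made up for illustration and are not part of the script in this post.

```python
import torch


def best_threshold(prob, truth):
    """Try every value in `prob` as a threshold and keep the best Jaccard index."""
    best_t, best_j = 0.0, 0.0
    for t in prob.flatten().tolist():
        pred = prob >= t                        # binarize with this candidate
        union = (pred | truth).sum().item()     # pixels set in either mask
        inter = (pred & truth).sum().item()     # pixels set in both masks
        j = inter / union if union else 0.0     # define J = 0 for an empty union
        if j > best_j:
            best_t, best_j = t, j
    return best_t, best_j


# Toy data (hypothetical): one probability map and its boolean ground truth.
prob = torch.tensor([[0.9, 0.2], [0.6, 0.1]])
truth = torch.tensor([[True, False], [True, False]])

t, j = best_threshold(prob, truth)
print(t, j)
```

On this toy input the candidate 0.6 separates the two foreground pixels from the background exactly, so it reaches a Jaccard index of 1.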

The data are available here (around 24MB). The file named mask.npy is an n_patch * 256 * 256 array of binary masks (containing only 0 and 1), and the file named pred_mask.npy is an n_patch * 256 * 256 array of probability maps (values between 0 and 1).

The thresholding is implemented in gen_mask, which takes a 3D pred_mask and thresholds each patch with its own value from a vector of thresholds. The jaccard function computes the Jaccard index of a 3D thresholded mask against the ground truth and returns a tensor with one value per patch (n_patch entries).
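As a quick check of the shapes involved, the following toy sketch applies one threshold per patch through broadcasting and computes one Jaccard value per patch. It runs on the CPU with made-up values and is a simplified stand-in for gen_mask and jaccard, not the functions themselves.

```python
import torch

# Toy stand-ins (hypothetical values): 2 patches of 2 x 2 pixels.
pred = torch.tensor([[[0.9, 0.2], [0.6, 0.1]],
                     [[0.3, 0.8], [0.7, 0.4]]])
truth = torch.tensor([[[1.0, 0.0], [1.0, 0.0]],
                      [[0.0, 1.0], [0.0, 0.0]]])

# One threshold per patch, shaped (n_patch, 1, 1) so that it broadcasts
# over the (n_patch, H, W) probability maps.
thresholds = torch.tensor([0.5, 0.65]).reshape(2, 1, 1)

binary = (pred >= thresholds).float()
inter = (binary * truth).sum(dim=(1, 2))
union = ((binary + truth) > 0).float().sum(dim=(1, 2))

# Guard against an empty union; the Jaccard index is set to 0 there.
ji = torch.where(union > 0, inter / union.clamp(min=1), torch.zeros_like(union))
print(ji.shape, ji)
```

The result has shape (2,), one entry per patch. The full script is: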

import numpy as np
import torch
import time


def gen_mask(mask_pred, threshold):
    mask_pred = mask_pred.clone()
    mask_pred[:, :, :][mask_pred[:, :, :] < threshold] = 0
    mask_pred[:, :, :][mask_pred[:, :, :] >= threshold] = 1
    return mask_pred


def jaccard(prediction, ground_truth):
    union = prediction + ground_truth
    union[union == 2] = 1
    intersection = prediction * ground_truth
    union = union.sum(axis=(1, 2))
    intersection = intersection.sum(axis=(1, 2))
    ji_nonezero = intersection[union != 0] / union[union != 0]
    ji = torch.zeros(intersection.shape).cuda()
    ji[union != 0] = ji_nonezero
    return ji


n_patch = 32
masks = np.load('./masks.npy')
pred_mask = np.load('./pred_mask.npy')

masks = torch.from_numpy(masks).cuda()
masks = masks.type(torch.float)
pred_mask = torch.from_numpy(pred_mask).cuda()

vector_pred = pred_mask.view(n_patch, -1).cuda()

best_thes = torch.zeros(n_patch)
best_ji = torch.zeros(n_patch)
best_thes = best_thes.cuda()
best_ji = best_ji.cuda()

start = time.time()
# I think this outer for loop is inevitable, since
# vector_pred.shape[1] is 65536,
# so we cannot simply create a tensor of shape n_patch * 65536 * 256 * 256,
# which is too large even for a GPU to handle
for i in range(vector_pred.shape[1]):
    cur_thres = vector_pred[:, i]
    cur_thres = cur_thres.cuda()
    cur_thres = cur_thres.reshape(n_patch, 1, 1)
    gend = gen_mask(pred_mask.squeeze(), cur_thres)
    gend = gend.type(torch.float)
    ji = jaccard(gend, masks)
    cur_thres = cur_thres.squeeze()
    best_thes[ji > best_ji] = cur_thres[ji > best_ji]
    best_ji[ji > best_ji] = ji[ji > best_ji]
    print(i, '/', vector_pred.shape[1], end="\r")

print('Best Threshold:', best_thes)
print('Best Jaccard Index:', best_ji)
end = time.time()
print(end - start)

The output, which is correct, is:

Best Threshold: tensor([6.8828e-01, 4.7082e-01, 1.2254e-01, 3.4189e-01, 2.8555e-01, 2.4655e-01, 4.9444e-01, 5.9245e-01, 5.0390e-01, 1.7931e-01, 2.3205e-01, 3.8314e-01, 4.5103e-01, 3.6109e-01, 3.4614e-01, 3.8766e-01, 3.6444e-01, 2.3667e-01, 2.0029e-01, 8.0435e-01, 4.9489e-01, 2.8066e-01, 1.4230e-04, 1.8089e-01, 2.2194e-01, 3.7781e-01, 3.5074e-01, 5.4690e-03, 2.6937e-01, 1.7834e-01, 2.2150e-01, 1.8330e-01], device='cuda:0')

Best Jaccard Index: tensor([0.9978, 0.9936, 0.9975, 0.9956, 0.9921, 0.9977, 0.9938, 0.9972, 0.9987, 0.9983, 0.9974, 0.9972, 0.9955, 0.9851, 0.9979, 0.9938, 0.9960, 0.9936, 0.9967, 0.9852, 0.9963, 0.9924, 0.9890, 0.9946, 0.9954, 0.9971, 0.9945, 0.9919, 0.9964, 0.9947, 0.9920, 0.9977], device='cuda:0')
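For reference, the comment in the loop about the n_patch * 65536 * 256 * 256 tensor can be backed by a rough size estimate. The sketch below assumes float32 storage (4 bytes per value), which is an assumption about how such a tensor would be held.

```python
n_patch = 32
n_thresholds = 256 * 256   # one candidate threshold per pixel
n_pixels = 256 * 256       # pixels per patch
bytes_per_value = 4        # assumption: float32; bool/uint8 would use 1

total_bytes = n_patch * n_thresholds * n_pixels * bytes_per_value
total_gib = total_bytes / 2**30
print(f"{total_gib:.0f} GiB")
```

This comes to 512 GiB in float32, and it would still be 128 GiB with one byte per value.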

Any suggestions to optimize the code snippet are welcome!