Update 1:
I managed to speed up the script by 100s using PyTorch logical and and or. However, this operation is only supported for type torch.uint8 which means I have to do type conversion. Now the performance is 232 seconds on a GPU.
The following is the modified version:
import numpy as np
import torch
import time
USE_CUDA = torch.cuda.is_available()
def gen_mask(mask_pred, threshold):
mask_pred = mask_pred.clone()
mask_pred[:, :, :][mask_pred[:, :, :] < threshold] = 0
mask_pred[:, :, :][mask_pred[:, :, :] >= threshold] = 1
return mask_pred.type(torch.uint8)
def jaccard(prediction, ground_truth):
union = prediction | ground_truth
intersection = prediction & ground_truth
union = union.sum(axis=(1, 2))
intersection = intersection.sum(axis=(1, 2))
union = union.type(torch.float)
intersection = intersection.type(torch.float)
union_nonzero_idx = union != 0
cur_jaccard_idx = torch.zeros(intersection.shape)
if USE_CUDA:
cur_jaccard_idx = cur_jaccard_idx.cuda()
cur_jaccard_idx[union_nonzero_idx] = intersection[union_nonzero_idx] / union[union_nonzero_idx]
return cur_jaccard_idx
groundtruth_masks = np.load('./masks.npy')
pred_mask = np.load('./pred_mask.npy')
n_patch = groundtruth_masks.shape[0]
groundtruth_masks = torch.from_numpy(groundtruth_masks)
groundtruth_masks = groundtruth_masks.type(torch.uint8)
pred_mask = torch.from_numpy(pred_mask)
vector_pred = pred_mask.view(n_patch, -1)
best_threshold_val = torch.zeros(n_patch)
best_jaccard_idx = torch.zeros(n_patch)
if USE_CUDA:
groundtruth_masks = groundtruth_masks.cuda()
pred_mask = pred_mask.cuda()
vector_pred = vector_pred.cuda()
best_threshold_val = best_threshold_val.cuda()
best_jaccard_idx = best_jaccard_idx.cuda()
start = time.time()
# I think this outer for loop is inevitable since
# vector_pred.shape[1] is 65536
# so we cannot simply create a matrix with n_patch * 65536 * 256 * 256
# which is too large even for a GPU to handle
for i in range(vector_pred.shape[1]):
cur_threshold_val = vector_pred[:, i]
cur_threshold_val = cur_threshold_val.reshape(n_patch, 1, 1)
thresholded_mask = gen_mask(pred_mask.squeeze(), cur_threshold_val)
cur_jaccard_idx = jaccard(thresholded_mask, groundtruth_masks)
cur_threshold_val = cur_threshold_val.squeeze()
best_threshold_val[cur_jaccard_idx >
best_jaccard_idx] = cur_threshold_val[cur_jaccard_idx > best_jaccard_idx]
best_jaccard_idx[cur_jaccard_idx > best_jaccard_idx] = cur_jaccard_idx[cur_jaccard_idx > best_jaccard_idx]
print(i, '/', vector_pred.shape[1], end="\r")
end = time.time()
print(best_threshold_val)
print(best_jaccard_idx)
print(end - start)
The attached code is computing n_patch probability maps at once instead of a single probability map. However, even I have optimized the implementation as vectorized as possible, the code still runs around 330 seconds on a GPU. Note the attached code is also executable on CPU, it will use an Nvidia GPU if you have one. A modified version of the code can be found further down.
Update:
I managed to speed up the script by 100s using PyTorch logical and and or. However, this operation is only supported for type torch.uint8 which means I have to do type conversion. Now the performance is 232 seconds on a GPU.
The following is the modified version:
import numpy as np
import torch
import time
USE_CUDA = torch.cuda.is_available()
def gen_mask(mask_pred, threshold):
mask_pred = mask_pred.clone()
mask_pred[:, :, :][mask_pred[:, :, :] < threshold] = 0
mask_pred[:, :, :][mask_pred[:, :, :] >= threshold] = 1
return mask_pred.type(torch.uint8)
def jaccard(prediction, ground_truth):
union = prediction | ground_truth
intersection = prediction & ground_truth
union = union.sum(axis=(1, 2))
intersection = intersection.sum(axis=(1, 2))
union = union.type(torch.float)
intersection = intersection.type(torch.float)
union_nonzero_idx = union != 0
cur_jaccard_idx = torch.zeros(intersection.shape)
if USE_CUDA:
cur_jaccard_idx = cur_jaccard_idx.cuda()
cur_jaccard_idx[union_nonzero_idx] = intersection[union_nonzero_idx] / union[union_nonzero_idx]
return cur_jaccard_idx
groundtruth_masks = np.load('./masks.npy')
pred_mask = np.load('./pred_mask.npy')
n_patch = groundtruth_masks.shape[0]
groundtruth_masks = torch.from_numpy(groundtruth_masks)
groundtruth_masks = groundtruth_masks.type(torch.uint8)
pred_mask = torch.from_numpy(pred_mask)
vector_pred = pred_mask.view(n_patch, -1)
best_threshold_val = torch.zeros(n_patch)
best_jaccard_idx = torch.zeros(n_patch)
if USE_CUDA:
groundtruth_masks = groundtruth_masks.cuda()
pred_mask = pred_mask.cuda()
vector_pred = vector_pred.cuda()
best_threshold_val = best_threshold_val.cuda()
best_jaccard_idx = best_jaccard_idx.cuda()
start = time.time()
# I think this outer for loop is inevitable since
# vector_pred.shape[1] is 65536
# so we cannot simply create a matrix with n_patch * 65536 * 256 * 256
# which is too large even for a GPU to handle
for i in range(vector_pred.shape[1]):
cur_threshold_val = vector_pred[:, i]
cur_threshold_val = cur_threshold_val.reshape(n_patch, 1, 1)
thresholded_mask = gen_mask(pred_mask.squeeze(), cur_threshold_val)
cur_jaccard_idx = jaccard(thresholded_mask, groundtruth_masks)
cur_threshold_val = cur_threshold_val.squeeze()
best_threshold_val[cur_jaccard_idx >
best_jaccard_idx] = cur_threshold_val[cur_jaccard_idx > best_jaccard_idx]
best_jaccard_idx[cur_jaccard_idx > best_jaccard_idx] = cur_jaccard_idx[cur_jaccard_idx > best_jaccard_idx]
print(i, '/', vector_pred.shape[1], end="\r")
end = time.time()
print(best_threshold_val)
print(best_jaccard_idx)
print(end - start)