Skip to main content
2 of 5
deleted 24 characters in body; edited tags; edited title; added 6 characters in body
200_success
  • 145.7k
  • 22
  • 191
  • 481

Bilinear image interpolation

I have written a bilinear interpolant, which is working moderately well except that is painfuly slow. How can rewrite the code to make it faster? Using opencv directly isn't a valid answer.

import numpy as np
import numpy.linalg as la
import matplotlib
import matplotlib.pyplot as plt
from skimage.draw import line_aa
import cv2


def draw_line(img, x0, y0, x1, y1):
    r, c, w = line_aa(int(x0), int(y0), int(x1), int(y1))
    img[r, c] = w

def synth_img(sM, sN, pts_src):
    img = np.zeros((sM, sN))
    draw_line(img.T, pts_src[0][0], pts_src[0][1], pts_src[1][0], pts_src[1][1])
    draw_line(img.T, pts_src[1][0], pts_src[1][1], pts_src[2][0], pts_src[2][1])
    draw_line(img.T, pts_src[2][0], pts_src[2][1], pts_src[3][0], pts_src[3][1])
    draw_line(img.T, pts_src[3][0], pts_src[3][1], pts_src[0][0], pts_src[0][1])

    return img


sM, sN = 1440, 1450
pts_src = np.array([[ 520, 100],[ 1410, 220],[1240, 1310],[ 30, 1070]]).astype('float32')
img = synth_img(sN, sM, pts_src)

# Create a target image
M, N = 1050, 1480
img_dst = np.zeros((N, M))
pts_dst = np.array([[0, 0], [M-1, 0], [M-1, N-1], [0, N-1]], dtype='float32') 

X = cv2.getPerspectiveTransform(pts_src, pts_dst)   # SRC to DST
X_inv = la.inv(X)                                   # DST to SRC

img_exp = cv2.warpPerspective(img, X, (M, N))

for y in np.arange(img_dst.shape[0]):
    for x in np.arange(img_dst.shape[1]):

        # Find the equivalent coordinates from DST space into SRC space
        Txy = X_inv  @ [x, y, 1]
        u, v, w = Txy / Txy[-1]

        # Find the neighboring points of v and u (src space)
        n = min(max(np.floor(v).astype('int'), 0), img.shape[0]-1)
        s = min(max(np.ceil(v).astype('int'),  0), img.shape[0]-1)
        w = min(max(np.floor(u).astype('int'), 0), img.shape[1]-1)
        e = min(max(np.ceil(u).astype('int'),  0), img.shape[1]-1)

        # Find the values in neighboring values of [u, v]
        q00 = img[n, w]
        q01 = img[n, e]
        q10 = img[s, w]
        q11 = img[s, e]

        x0, x1, y0, y1 = w, e, n, s

        A = np.array([
            [1, x0, y0, x0*y0],
            [1, x0, y1, x0*y1],
            [1, x1, y0, x1*y0],
            [1, x1, y1, x1*y1]
        ])

        b = np.array([[
            q00, q01, q10, q11
        ]]).T

        a = la.lstsq(A, b, rcond=None)[0].ravel()
        z = a[0] + a[1]*u + a[2]*v + a[3]*u*v
        img_dst[y, x] = z


plt.close()
fig, ax = plt.subplots(2,2, figsize=(6,7), dpi=300, sharey=True, sharex=True)
ax[0][0].imshow(img, cmap='gray')
ax[0][0].set_title('Original')
ax[0][1].imshow(img_exp-img_dst, cmap='gray')
ax[0][1].set_title('Diff')
ax[1][0].imshow(img_dst, cmap='gray')
ax[1][0].set_title('Result')
ax[1][1].imshow(img_exp, cmap='gray')
ax[1][1].set_title('Expected')
#plt.show()
Lin
  • 357
  • 2
  • 10