Skip to main content
4 of 4
edited title

Fast complex absolute argmax in Cython

I'm thinking I implemented it optimally, but somehow it's much slower than what should be much slower, np.argmax(np.abs(x)). Where am I off?

Code rationale & results

  • Mathematically, abs is sqrt(real**2 + imag**2), but argmax(abs(x)) == argmax(abs(x)**2), so no need for square root
  • np.abs(x) also allocates and writes an array. Instead I overwrite a single value, current_abs2, which should eliminate allocation and only leave writing
  • Argmax logic should be identical to NumPy's (I've not checked but only one best way to do it?)
  • Views (R, I) are for... I don't recall, saw somewhere

So savings are in dropping sqrt and len(x)-sized allocation. Yet it's much slower...

%timeit np.argmax(np.abs(x))
%timeit abs_argmax(x.real, x.imag)
409  µs ± 2.33 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
3.09 ms ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Here's the generated C code, just the function; the whole _optimized.c is 26000 lines.

The following Numba achieves 108 µs, very satisfactory, though I'm interested in why Cython fails.

Code

import cython

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int abs_argmax(double[:] re, double[:] im):
    # initialize variables
    cdef Py_ssize_t N = re.shape[0]
    cdef double[:] R = re  # view
    cdef double[:] I = im  # view

    cdef Py_ssize_t i = 0
    cdef int max_idx = 0
    cdef double current_max = 0
    cdef double current_abs2 = 0

    # main loop
    while i < N:
        current_abs2 = R[i]**2 + I[i]**2
        if current_abs2 > current_max:
            max_idx = i
            current_max = current_abs2
        i += 1

    # return
    return max_idx

Setup & execution

I use python setup.py build_ext --inplace, setup.py shown at bottom. Then,

import numpy as np
from _optimized import abs_argmax

x = np.random.randn(100000) + 1j*np.random.randn(100000)
%timeit np.argmax(np.abs(x))
%timeit abs_argmax(x.real, x.imag)

setup.py (I forget the rationale, just took certain recommendations)

from distutils import _msvccompiler
_msvccompiler.PLAT_TO_VCVARS['win-amd64'] = 'amd64'

from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy as np

setup(
    ext_modules=cythonize(Extension("_optimized", ["_optimized.pyx"]),
                          language_level=3),
    include_dirs=[np.get_include()],
)

Environment

Windows 11, i7-13700HX CPU, Python 3.11.4, Cython 3.0.0, setuptools 68.0.0, numpy 1.24.4