
I'm practicing problems for the ICPC, and one of them requires using an FFT to compute the product of two polynomials efficiently. Since this is for the ICPC, I have to implement the FFT from scratch using only the standard library (NumPy is not allowed). This is my solution so far, but it's still too slow. Any tips on how I can improve its performance?

@cache
def reverse(num, lg_n):
    res = 0
    for i in range(lg_n):
        if num & (1 << i):
            res |= 2 ** (lg_n - 1 - i)
    return res


def fft_inplace(p: list[complex], inverse: bool = False) -> None:
    n = len(p)
    lg_n = ceil(log2(n))

    for i in range(n):
        rev = reverse(i, lg_n)
        if i < rev:
            p[i], p[rev] = p[rev], p[i]

    i = 2
    while i <= n:
        ang = 2 * pi / i * (1 if inverse else -1)
        w = exp(ang * 1j)

        for j in range(0, n, i):
            current_w = 1

            for k in range(i // 2):
                u = p[j + k]
                v = p[j + k + i // 2] * current_w
                p[j + k] = u + v
                p[j + k + i // 2] = u - v

                current_w *= w
        i *= 2

    if inverse:
        for i in range(n):
            p[i] /= n


def multiply_polynomials(p: list[float], q: list[float]) -> list[float]:
    """Compute the product of two polynomials."""
    fft_inplace(p)  # type: ignore
    fft_inplace(q)  # type: ignore
    r = [a * b for a, b in zip(p, q)]
    fft_inplace(r, inverse=True)  # type: ignore
    return [val.real for val in r]
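As written, `fft_inplace` only works when `len(p)` is a power of two (otherwise some bit-reversed indices fall outside the list), and `multiply_polynomials` only yields the full product when `p` and `q` share that length and it is at least `len(p) + len(q) - 1`; shorter inputs wrap around into a cyclic convolution. A minimal padding sketch, assuming non-empty coefficient lists stored lowest degree first (`next_power_of_two` and `pad_for_product` are hypothetical helpers, not from the post):

```python
def next_power_of_two(k: int) -> int:
    """Smallest power of two that is >= k (for k >= 1)."""
    return 1 << (k - 1).bit_length()


def pad_for_product(
    p: list[float], q: list[float]
) -> tuple[list[complex], list[complex], int]:
    """Zero-pad p and q to a common power-of-two length.

    The length is chosen so the full product, which has
    len(p) + len(q) - 1 coefficients, fits without wrapping around.
    Returns both padded lists and the number of meaningful output
    coefficients. Assumes p and q are non-empty.
    """
    result_len = len(p) + len(q) - 1
    n = next_power_of_two(result_len)
    pp = [complex(x) for x in p] + [0j] * (n - len(p))
    qq = [complex(x) for x in q] + [0j] * (n - len(q))
    return pp, qq, result_len
```

A call would then look like `pp, qq, m = pad_for_product(p, q)` followed by `multiply_polynomials(pp, qq)[:m]`. For problems with integer coefficients, applying `round()` to each result removes floating-point noise such as `21.999999999`.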

This is the flamegraph from running cProfile:

[flamegraph image not reproduced]


1 Answer


I assumed the following imports:

from functools import cache
from math import ceil, log2, pi
from cmath import exp

I suggest the following changes for a 2x to 3x speedup (the time complexity is unchanged):

  • Use a more efficient method to compute the bit reversal.
  • Precompute the (i, rev) pairs once, so they aren't recomputed for each of the three FFT calls.
  • Shuffle and combine variables to reduce the number of operations in the inner loop.

def compute_rev_pairs(n):
    output = []
    rev = 0
    for i in range(1, n):
        bit = n >> 1
        while rev & bit:
            rev ^= bit
            bit >>= 1
        rev ^= bit
        if i < rev:
            output.append((i, rev))
    return output


def fft_inplace2(p, i_rev, n, ang_c):
    for i, rev in i_rev:
        p[i], p[rev] = p[rev], p[i]

    i2 = 1  # i2 = i // 2
    i = 2
    while i <= n:
        w = exp(ang_c / i)  # i-th root of unity; ang_c = +/-2*pi*1j

        for j in range(0, n, i):
            current_w = 1

            for jk in range(j, j + i2):
                u = p[jk]  # jk = j + k
                v = p[jk + i2] * current_w
                p[jk] = u + v
                p[jk + i2] = u - v

                current_w *= w
        i2 = i
        i <<= 1


def multiply_polynomials2(p, q):
    n = len(p)
    ang_c = pi * 2j
    rev_pairs = compute_rev_pairs(n)
    fft_inplace2(p, rev_pairs, n, ang_c)
    fft_inplace2(q, rev_pairs, n, ang_c)
    r = [a * b for a, b in zip(p, q)]
    fft_inplace2(r, rev_pairs, n, -ang_c)
    return [c.real / n for c in r]
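Another way to get the bit-reversal permutation, not taken from the answer above, is to build the whole table with a recurrence: dropping the lowest bit of i and reversing gives rev[i >> 1] >> 1, and the dropped bit becomes the new top bit. This is a sketch assuming n is a power of two; `bit_reverse_table` and `rev_pairs_from_table` are hypothetical names:

```python
def bit_reverse_table(n: int) -> list[int]:
    """Return rev where rev[i] is i with its log2(n) low bits reversed.

    Assumes n is a power of two (n >= 1). Uses the recurrence
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg_n - 1)),
    so each entry costs O(1) instead of O(log n).
    """
    lg_n = n.bit_length() - 1
    rev = [0] * n
    for i in range(1, n):
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg_n - 1))
    return rev


def rev_pairs_from_table(n: int) -> list[tuple[int, int]]:
    """Same output as compute_rev_pairs: only the pairs with i < rev[i]."""
    rev = bit_reverse_table(n)
    return [(i, r) for i, r in enumerate(rev) if i < r]
```

It could replace `compute_rev_pairs(n)` inside `multiply_polynomials2`; since the pairs are computed only once per multiplication either way, any difference in CPython is likely small and has not been measured here.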

We can put together a small test:

import numpy as np
import time

test_n = 2 ** 18
p = np.random.random(test_n)
q = np.random.random(test_n)

start1 = time.time()
pq1 = multiply_polynomials(p.tolist(), q.tolist())
diff1 = time.time() - start1

start2 = time.time()
pq2 = multiply_polynomials2(p.tolist(), q.tolist())
diff2 = time.time() - start2

print("{:.2f}s vs {:.2f}s, {:.2f}x faster".format(
    diff1, diff2, diff1 / diff2))
print("IDENTICAL: {}".format(np.allclose(pq1, pq2, atol=0)))

This yields the following output:

6.56s vs 2.48s, 2.64x faster
IDENTICAL: True
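The `np.allclose` check only shows that the two FFT versions agree with each other (within its default relative tolerance), not that either one is correct. For small inputs, a direct O(len(p) * len(q)) convolution, where each p[i] * q[j] is added to coefficient i + j, gives a ground truth to compare against. `multiply_naive` is a hypothetical helper, not part of either post:

```python
def multiply_naive(p: list[float], q: list[float]) -> list[float]:
    """Direct polynomial product, for checking an FFT on small inputs.

    Coefficients are stored lowest degree first, so p[i] is the
    coefficient of x**i. Runs in O(len(p) * len(q)) time.
    """
    if not p or not q:
        return []
    r = [0.0] * (len(p) + len(q) - 1)
    for i, a in enumerate(p):
        for j, b in enumerate(q):
            r[i + j] += a * b  # x**i * x**j contributes to x**(i + j)
    return r


# (1 + 2x + 3x^2) * (4 + 5x) = 4 + 13x + 22x^2 + 15x^3
print(multiply_naive([1, 2, 3], [4, 5]))  # [4.0, 13.0, 22.0, 15.0]
```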

Hope this helps! The first two suggestions come from cp-algorithms.com. If you need still more performance, I unfortunately don't see an obvious route other than switching to a language like C++.

  • Caching (i, rev) pairs would not be much of a change from @cache.
    – greybeard
    Commented Aug 6, 2024 at 5:55
  • Thanks, it did help a lot! Unfortunately, I don't think the question was meant to be solved in Python: a much simpler implementation passed in C++, so I believe the Python time limit for this problem is just completely wrong.
    Commented Aug 6, 2024 at 11:00
