2
\$\begingroup\$

Here's the Rabin-Karp implementation that I've written in Python:

from functools import lru_cache
from itertools import islice
from typing import List, Optional


BASE = 139
MOD = 2 ** 31 - 1


@lru_cache(maxsize=None)
def char_to_int(c: str) -> int:
    return ord(c) + 1

def calculate_hash(s: str, length: Optional[int] = None) -> int:
    ret_hash = 0
    if length is not None:
        s = islice(s, 0, length)
    power = len(s) - 1 if length is None else length - 1
    for c in s:
        ret_hash += char_to_int(c) * BASE ** power
        power -= 1
    return ret_hash % MOD

def roll_hash(prev_hash: int, prev_char: str, next_char: str, length: int) -> int:
    new_hash = ((prev_hash - char_to_int(prev_char) * BASE ** (length - 1)) * BASE) % MOD
    new_hash += char_to_int(next_char)
    return new_hash % MOD

def rabin_karp(text: str, pattern: str) -> List[int]:
    """
    Returns a list of indices where each entry corresponds to the starting index of a
    substring of `text` that exactly matches `pattern`
    """
    p_hash = calculate_hash(pattern)
    n = len(pattern)
    curr_hash = calculate_hash(text, n)
    indices = []
    if p_hash == curr_hash:
        indices.append(0)
    for i in range(1, len(text) - n + 1):
        curr_hash = roll_hash(
            curr_hash, text[i - 1], text[i + n - 1], n
        )
        if p_hash == curr_hash:
            indices.append(i)
    return indices

if __name__ == "__main__":
    with open("lorem_ipsum.txt", "r") as f:
        text = f.read()
    pattern = "lorem"
    indices = rabin_karp(text, pattern)
    print(f"indices: {indices}")

I'm trying to optimize the code as much as possible, so I've tried to do some dynamic code analysis to better understand bottlenecks. I used cProfile to understand the function calls and made changes to the code accordingly to arrive at the above code. Here is the final output from cProfile:

Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    60301    0.091    0.000    0.091    0.000 rabin_karp.py:25(roll_hash)
        1    0.035    0.035    0.126    0.126 rabin_karp.py:30(rabin_karp)
       42    0.000    0.000    0.000    0.000 rabin_karp.py:11(char_to_int)
        2    0.000    0.000    0.000    0.000 rabin_karp.py:15(calculate_hash)
       57    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
       42    0.000    0.000    0.000    0.000 {built-in method builtins.ord}
        3    0.000    0.000    0.000    0.000 {built-in method builtins.len}

Are there any other ways I could further optimize the code? Another interesting thing to note is that adding @lru_cache actually increases execution time as measured by timeit, even though the caching mechanism reduces the number of function calls to char_to_int() (from 120612 to 42).
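One plausible explanation, stated here as an assumption rather than something the profile above demonstrates, is that every call through the lru_cache wrapper still has to hash its argument and do a dictionary lookup, which may cost more than ord(c) + 1 itself. A minimal sketch for checking this with timeit follows; the names char_to_int_cached and char_to_int_plain and the sample string are hypothetical:

```python
from functools import lru_cache
import timeit


@lru_cache(maxsize=None)
def char_to_int_cached(c: str) -> int:
    # Same mapping as in the question, wrapped in a cache.
    return ord(c) + 1


def char_to_int_plain(c: str) -> int:
    # Same mapping without any caching.
    return ord(c) + 1


if __name__ == "__main__":
    sample = "lorem ipsum dolor sit amet " * 1000  # arbitrary test input
    for fn in (char_to_int_cached, char_to_int_plain):
        elapsed = timeit.timeit(lambda: [fn(c) for c in sample], number=20)
        print(f"{fn.__name__}: {elapsed:.3f}s")
```

The timings depend on the machine and the Python version, so the sketch prints them instead of assuming which variant wins.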

\$\endgroup\$
4
  • 1
    \$\begingroup\$ pow() is quite slow operation ... dunno if you can find good enough pow() approximation for the job... \$\endgroup\$ Commented May 11, 2021 at 10:39
  • 1
    \$\begingroup\$ FYI the same task can be done with [m.start(0) for m in re.finditer(pattern, text)]. It runs in milliseconds over a file of 100K lines, while it takes seconds using your rabin_karp. \$\endgroup\$ Commented May 11, 2021 at 11:04
  • \$\begingroup\$ @Marc I appreciate your suggestion, but I wanted to implement Rabin Karp from scratch for the purpose of understanding the algorithm from first principles as well as to refresh my code optimization skills! \$\endgroup\$ Commented May 12, 2021 at 4:38
  • \$\begingroup\$ Pow() LUT didn't give much speedup ... even it was just an "direct access" array[1, 139, 19321, 2685619, etc]... . \$\endgroup\$ Commented May 12, 2021 at 13:37

1 Answer

2
\$\begingroup\$

Disclaimer: I did not get any noticeable performance improvements but here are a few ideas nonetheless. Also, your code looks really good to me and there is not much to improve.

In calculate_hash, you compute a power in the body of the loop. From the way the exponent is updated, the successive values are BASE ** (length - 1), BASE ** (length - 2), and so on. Another option is to compute the initial value once with '**' and then update the power by integer division by BASE (using '//' rather than '/', so that the value stays an exact int instead of turning into a float).

We could get something like:

def calculate_hash(s: str, length: Optional[int] = None) -> int:
    if length is not None:
        s = islice(s, 0, length)
    else:
        length = len(s)
    power = BASE ** (length - 1)
    ret_hash = 0
    for c in s:
        ret_hash += char_to_int(c) * power
        # Integer division keeps `power` an exact int.
        power //= BASE
    return ret_hash % MOD
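Another way to avoid the repeated exponentiation, sketched here as an alternative and not as part of the code above, is Horner's rule: multiply the running hash by BASE, add the next character, and reduce modulo MOD at every step so the intermediate values stay small. The name calculate_hash_horner is hypothetical; BASE, MOD and the ord(c) + 1 mapping are copied from the question.

```python
from itertools import islice
from typing import Optional

BASE = 139
MOD = 2 ** 31 - 1


def calculate_hash_horner(s: str, length: Optional[int] = None) -> int:
    # Only hash the first `length` characters when a length is given.
    chars = s if length is None else islice(s, 0, length)
    ret_hash = 0
    for c in chars:
        # Horner's rule: h = h * BASE + value, kept below MOD at every step.
        ret_hash = (ret_hash * BASE + ord(c) + 1) % MOD
    return ret_hash
```

Because the modulo is applied inside the loop, no intermediate value grows beyond roughly BASE * MOD, whatever the pattern length.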

In roll_hash, the computation is split across different lines as the first line is getting too big. In order to keep the expressions easier to read, it may be useful to define additional local variables such as: prev_int, next_int = char_to_int(prev_char), char_to_int(next_char).

Then, a few things can be spotted:

  • we compute m = (val1 % MOD) and then (m + val2) % MOD: we could perform the modulo operation just once
  • the ((prev_hash - char_to_int(prev_char) * BASE ** (length - 1)) * BASE) expression can be slightly simplified by expanding it, we get: prev_hash * BASE - char_to_int(prev_char) * BASE ** length. It is much easier to read and should be just as efficient (if not more).

At the end, the whole function becomes:

def roll_hash(prev_hash: int, prev_char: str, next_char: str, length: int) -> int:
    prev_int, next_int = char_to_int(prev_char), char_to_int(next_char)
    return (next_int + prev_hash * BASE - prev_int * BASE ** length) % MOD
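A quick sanity check that the expanded expression agrees with the original one. This is only a sketch: roll_original and roll_expanded are hypothetical names, they take the already converted character values, and the inputs are arbitrary.

```python
BASE = 139
MOD = 2 ** 31 - 1


def roll_original(prev_hash: int, prev_int: int, next_int: int, length: int) -> int:
    # Two-step form from the question: reduce, add the new character, reduce again.
    new_hash = ((prev_hash - prev_int * BASE ** (length - 1)) * BASE) % MOD
    new_hash += next_int
    return new_hash % MOD


def roll_expanded(prev_hash: int, prev_int: int, next_int: int, length: int) -> int:
    # Expanded form with a single modulo at the end.
    return (next_int + prev_hash * BASE - prev_int * BASE ** length) % MOD


if __name__ == "__main__":
    for prev_hash in (0, 1, 12345, MOD - 1):
        for length in (1, 5, 20):
            a = roll_original(prev_hash, 109, 102, length)
            b = roll_expanded(prev_hash, 109, 102, length)
            assert a == b
    print("both forms agree")
```

Both forms agree even when the subtraction goes negative, because Python's % returns a non-negative result for a positive modulus.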

As mentioned in the comments, the power operation is quite expensive. Maybe we could try to compute it just once and provide it to the roll_hash function.

Also, we could take this chance to use a lesser-known feature of the pow builtin: its optional third (modulus) argument. Here, we may as well compute BASE ** length % MOD directly, so that we get to handle smaller values later on (this is an improvement when the length of the pattern is greater than 4, since BASE ** 5 already exceeds MOD).
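For reference, pow(base, exp, mod) returns the same value as base ** exp % mod without building the full power first. A small sketch, using the constants from the question and an arbitrary pattern length of 50 chosen only for illustration:

```python
BASE = 139
MOD = 2 ** 31 - 1
n = 50  # arbitrary pattern length, chosen only for illustration

full = BASE ** n             # exact power, a very large integer
reduced = pow(BASE, n, MOD)  # modular exponentiation, result is below MOD

assert reduced == full % MOD
assert reduced < MOD
# The threshold mentioned above: BASE ** 4 still fits below MOD, BASE ** 5 does not.
assert BASE ** 4 < MOD < BASE ** 5
print(full.bit_length(), reduced.bit_length())
```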

At this stage, we get:

def roll_hash(prev_hash: int, prev_char: str, next_char: str, base_pow_length: int) -> int:
    prev_int, next_int = char_to_int(prev_char), char_to_int(next_char)
    return (next_int + prev_hash * BASE - prev_int * base_pow_length) % MOD



def rabin_karp(text: str, pattern: str) -> List[int]:
    """
    Returns a list of indices where each entry corresponds to the starting index of a
    substring of `text` that exactly matches `pattern`
    """
    p_hash = calculate_hash(pattern)
    n = len(pattern)
    curr_hash = calculate_hash(text, n)
    indices = []
    if p_hash == curr_hash:
        indices.append(0)
    base_pow_length = pow(BASE, n, MOD)
    for i in range(1, len(text) - n + 1):
        curr_hash = roll_hash(
            curr_hash, text[i - 1], text[i + n - 1], base_pow_length
        )
        if p_hash == curr_hash:
            indices.append(i)
    return indices
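One thing the code above does not do is confirm a hash hit: two different substrings can share the same hash modulo MOD, so comparing hashes alone may report a false positive. A common safeguard, shown here as a sketch and not as part of the reviewed code, is to compare the actual substring whenever the hashes match. The name rabin_karp_verified is hypothetical, and returning an empty list for an empty pattern or a pattern longer than the text is an assumption made for this example.

```python
from typing import List

BASE = 139
MOD = 2 ** 31 - 1


def rabin_karp_verified(text: str, pattern: str) -> List[int]:
    n = len(pattern)
    # Assumption: no matches are reported for an empty or over-long pattern.
    if n == 0 or n > len(text):
        return []
    p_hash = 0
    curr_hash = 0
    for i in range(n):
        # Hash the pattern and the first window with the same recurrence.
        p_hash = (p_hash * BASE + ord(pattern[i]) + 1) % MOD
        curr_hash = (curr_hash * BASE + ord(text[i]) + 1) % MOD
    base_pow_length = pow(BASE, n, MOD)
    indices = []
    for i in range(len(text) - n + 1):
        if i > 0:
            # Slide the window: drop text[i - 1], append text[i + n - 1].
            prev_int = ord(text[i - 1]) + 1
            next_int = ord(text[i + n - 1]) + 1
            curr_hash = (next_int + curr_hash * BASE - prev_int * base_pow_length) % MOD
        # Only accept a hash hit after confirming the substring itself.
        if curr_hash == p_hash and text[i:i + n] == pattern:
            indices.append(i)
    return indices
```

The extra comparison only runs on hash hits, so with few collisions it adds little work beyond the true matches.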
\$\endgroup\$
1
  • \$\begingroup\$ You raise some great points! I did try storing the power operation and it didn't seem to make any noticeable performance changes, but thanks to you TIL that there is a mod parameter in the power function! \$\endgroup\$ Commented May 12, 2021 at 4:33
