Here is a new version of a segmented, wheel-factorized Sieve of Eratosthenes. It currently uses mod 30 wheel factorization to eliminate multiples of 2, 3, and 5 from the sieve data structure to gain speed. It wraps the wheel with segmentation to reduce its memory footprint so it can scale up to N in the billions and beyond. (yeah, I know, Buzz Lightyear)
This is a follow-up to an earlier version. Thanks to @GZ0 for comments, including the warning about how soon Python 2.7 will go unsupported, and a huge thanks to @QuantumChris for the thorough code review, especially for encouraging me to use OOP for modularity.
I decided to use a class for everything related to the mod 30 wheel. I hope that makes the design clearer, since the wheel and segmentation code are now separate.
The performance degraded by ~1.5%. I think that is fine, since:
- perhaps more people will read it. More eyeballs on any code is a Good Thing in my opinion.
- cProfile output is more helpful because the code is more granular. Woo-hoo! It now shows that cull_one_multiple is the hot spot followed by segmentedSieve.
- it will make it easy to replace the multiple culling code, for example with a mod 210 wheel (to also eliminate multiples of 7), with only tiny changes outside of the wheel class. This may make up for the degradation if done carefully.
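To put rough numbers on the mod 30 vs. mod 210 trade-off, here is a small sketch that builds the residues ("teeth") and gaps for any wheel modulus. The helper names `wheel_teeth` and `wheel_gaps` are made up for this illustration and are not part of the posted code. Note that the gaps here start from residue 1, so for mod 30 they are a rotation of the `gaps` array in the wheel class, which starts at 7.

```python
from math import gcd


def wheel_teeth(modulus):
    """Residues in [1, modulus] that are coprime to the modulus."""
    return [r for r in range(1, modulus + 1) if gcd(r, modulus) == 1]


def wheel_gaps(modulus):
    """Distance from each tooth to the next one, wrapping around the wheel."""
    teeth = wheel_teeth(modulus)
    return [(teeth[(i + 1) % len(teeth)] - t) % modulus
            for i, t in enumerate(teeth)]


for m in (30, 210):
    teeth = wheel_teeth(m)
    # fraction of all integers that still need a bit in the sieve
    print(f"mod {m}: {len(teeth)} teeth, density {len(teeth) / m:.4f}")
```

A mod 210 wheel has 48 teeth, so it stores 6/7 as many bits as the mod 30 wheel for the same range, but 48 is not a power of two, so the shift and mask tricks would need a divide and a mod (or a lookup) instead.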
Please let me know what you think.
#!/usr/bin/python3 -Wall
"""program to find all primes <= n, using a segmented wheel sieve"""
from sys import argv
from math import log
from time import time
# non standard packages
from bitarray import bitarray
# tuning parameters
CUTOFF = 1e4 # small for debug
SIEVE_SIZE = 2 ** 20 # in bytes, tiny (i.e. 1) for debug
CLOCK_SPEED = 1.6 # in GHz, on my i5-6285U laptop
def progress(current, total):
    """Display a progress bar on the terminal."""
    size = 60
    x = size * current // total
    print(f'\rSieving: [{"#" * x}{"." * (size - x)}] {current}/{total}', end="")

def seg_wheel_stats(n):
    """Returns only the stats from the segmented sieve."""
    return(segmentedSieve(n, statsOnly=True))

def print_sieve_size(sieve):
    print("sieve size:", end=' ')
    ss = len(memoryview(sieve))
    print(ss//1024, "KB") if ss > 1024 else print(ss, "bytes")

def prime_gen_wrapper(n):
    """
    Decide whether to use the segmented sieve or a simpler version.
    Stops recursion.
    """
    return smallSieve(n + 1) if n < CUTOFF else segmentedSieve(n)
    # NB: rwh_primes1 (a.k.a. smallSieve) returns primes < N.
    # We need sieving primes <= sqrt(limit), hence the +1

def smallSieve(n):
    """Returns a list of primes less than n."""
    # a copy of Robert William Hanks' odds only rwh_primes1
    # used to get sieving primes for smaller ranges
    # from https://stackoverflow.com/a/2068548/11943198
    sieve = [True] * (n // 2)
    for i in range(3, int(n ** 0.5) + 1, 2):
        if sieve[i // 2]:
            sieve[i * i // 2::i] = [False] * ((n - i * i - 1) // (2 * i) + 1)
    return [2] + [2 * i + 1 for i in range(1, n // 2) if sieve[i]]
class PrimeMultiple:
    """Contains information about sieving primes and their multiples"""
    __slots__ = ['prime', 'multiple', 'wheel_index']

    def __init__(self, prime):
        self.prime = prime

    def update(self, multiple, wheel_index):
        self.multiple = multiple
        self.wheel_index = wheel_index

    def update_new_mult(self, multiple, wheel_index, wheel):
        self.update(multiple, wheel_index)
        wheel.inc_mults_in_use()

class m30_wheel:
    """Contains all methods and data unique to a mod 30 (2, 3, 5) wheel"""
    # mod 30 wheel factorization based on a non-segmented version found here
    # https://programmingpraxis.com/2012/01/06/pritchards-wheel-sieve/
    # in a comment by Willy Good

    def __init__(self, sqrt):
        # mod 30 wheel constant arrays
        self.skipped_primes = [2, 3, 5] # the wheel skips multiples of these
        self.wheel_primes = [7, 11, 13, 17, 19, 23, 29, 31]
        self.wheel_primes_m30 = [7, 11, 13, 17, 19, 23, 29, 1]
        self.gaps = [4,2,4,2,4,6,2,6, 4,2,4,2,4,6,2,6] # 2 loops for overflow
        self.wheel_indices = [0,0,0,0,1,1,2,2,2,2, 3,3,4,4,4,4,5,5,5,5, 5,5,6,6,7,7,7,7,7,7]
        self.round2wheel = [7,7,0,0,0,0,0,0,1,1, 1,1,2,2,3,3,3,3,4,4, 5,5,5,5,6,6,6,6,6,6]
        # get sieving primes recursively,
        # skipping over those eliminated by the wheel
        self.mults = [PrimeMultiple(p) for p in prime_gen_wrapper(sqrt)[len(self.skipped_primes):]]
        self.mults_in_use = 0

    def inc_mults_in_use(self):
        self.mults_in_use += 1

    def get_skipped_primes(self):
        """Returns tiny primes which this wheel ignores otherwise"""
        return self.skipped_primes

    def num2ix(self, n):
        """Return the wheel index for n."""
        n = n - 7 # adjust for wheel starting at 7 vs. 0
        return (n//30 << 3) + self.wheel_indices[n % 30]

    def ix2num(self, i):
        """Return the number corresponding to wheel index i."""
        return 30 * (i >> 3) + self.wheel_primes[i & 7]

    def cull_one_multiple(self, sieve, lo_ix, high, pm):
        """Cull one prime multiple from this segment"""
        p = pm.prime
        wx = pm.wheel_index
        mult = pm.multiple - 7 # compensate for wheel starting at 7 vs. 0
        p8 = p << 3
        for j in range(8):
            cull_start = ((mult // 30 << 3)
                          + self.wheel_indices[mult % 30] - lo_ix)
            sieve[cull_start::p8] = False
            mult += p * self.gaps[wx]
            wx += 1
        # calculate the next multiple of p and its wheel index
        # f = next factor of a multiple of p past this segment
        f = (high + p - 1)//p
        f_m30 = f % 30
        # round up to next wheel index to eliminate multiples of 2,3,5
        wx = self.round2wheel[f_m30]
        # normal multiple of p past this segment
        mult = p * (f - f_m30 + self.wheel_primes_m30[wx])
        pm.update(mult, wx) # save multiple and wheel index

    def cull_segment(self, sieve, lo_ix, high):
        """Cull all prime multiples from this segment"""
        # generate new multiples of sieving primes and wheel indices
        # needed in this segment
        for pm in self.mults[self.mults_in_use:]:
            p = pm.prime
            psq = p * p
            if psq > high:
                break
            pm.update_new_mult(psq, self.num2ix(p) & 7, self)
        # sieve the current segment
        for pm in self.mults[:self.mults_in_use]:
            # iterate over all prime multiples relevant to this segment
            if pm.multiple <= high:
                self.cull_one_multiple(sieve, lo_ix, high, pm)
def segmentedSieve(limit, statsOnly=False):
    """
    Sieves potential prime numbers up to and including limit.
    statsOnly (default False) controls the return.
    when False, returns a list of primes found.
    when True, returns a count of the primes found.
    """
    # segmentation originally based on Kim Walisch's
    # simple C++ example of segmentation found here:
    # https://github.com/kimwalisch/primesieve/wiki/Segmented-sieve-of-Eratosthenes
    assert(limit > 6)
    sqrt = int(limit ** 0.5)
    wheel = m30_wheel(sqrt)
    lim_ix = wheel.num2ix(limit)
    sieve_bits = SIEVE_SIZE * 8
    while (sieve_bits >> 1) >= max(lim_ix, 1):
        sieve_bits >>= 1 # adjust the sieve size downward for small N
    sieve = bitarray(sieve_bits)
    num_segments = (lim_ix + sieve_bits - 1) // sieve_bits # round up
    show_progress = False
    if statsOnly: # outer loop?
        print_sieve_size(sieve)
        if limit > 1e8:
            show_progress = True
    outPrimes = wheel.get_skipped_primes() # these may be needed for output
    count = len(outPrimes)
    # loop over all the segments
    for lo_ix in range(0, lim_ix + 1, sieve_bits):
        high = wheel.ix2num(lo_ix + sieve_bits) - 1
        sieve.setall(True)
        if show_progress:
            progress(lo_ix // sieve_bits, num_segments)
        wheel.cull_segment(sieve, lo_ix, high)
        # handle any extras in the last segment
        top = lim_ix - lo_ix + 1 if high > limit else sieve_bits
        # collect results from this segment
        if statsOnly:
            count += sieve[:top].count() # a lightweight way to get a result
        else:
            for i in range(top): # XXX not so lightweight
                if sieve[i]:
                    x = i + lo_ix
                    # ix2num(x) inlined below, performance is sensitive here
                    p = 30 * (x >> 3) + wheel.wheel_primes[x & 7]
                    outPrimes.append(p)
    if show_progress:
        progress(num_segments, num_segments)
        print()
    return count if statsOnly else outPrimes

if __name__ == '__main__':
    a = '1e8' if len(argv) < 2 else argv[1]
    n = int(float(a))
    start = time()
    count = segmentedSieve(n, statsOnly=True)
    elapsed = time() - start
    BigOculls = n * log(log(n, 2), 2)
    cycles = CLOCK_SPEED * 1e9 * elapsed
    cyclesPerCull = cycles/BigOculls
    print(f"pi({a}) = {count}")
    print(f"{elapsed:.3} seconds, {cyclesPerCull:.2} cycles/N log log N)")
    if count < 500:
        print(segmentedSieve(n))
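For readers who want to poke at the index mapping in isolation, here is a minimal standalone sketch of the `num2ix` / `ix2num` pair from the wheel class, copied out as free functions (an arrangement made only for this illustration). It checks that the two are inverses on wheel indices, and that multiples of a prime p which are 30p apart land exactly 8p bits apart, which is why the culling slice can use a stride of `p << 3`.

```python
# Same constant tables as in m30_wheel, lifted to module level for the demo.
WHEEL_PRIMES = [7, 11, 13, 17, 19, 23, 29, 31]
WHEEL_INDICES = [0,0,0,0,1,1,2,2,2,2, 3,3,4,4,4,4,5,5,5,5, 5,5,6,6,7,7,7,7,7,7]


def num2ix(n):
    """Wheel index of n (8 indices per block of 30 numbers, starting at 7)."""
    n -= 7
    return (n // 30 << 3) + WHEEL_INDICES[n % 30]


def ix2num(i):
    """Number represented by wheel index i."""
    return 30 * (i >> 3) + WHEEL_PRIMES[i & 7]


# Round trip: every wheel index maps to a number and back to itself.
assert all(num2ix(ix2num(i)) == i for i in range(1000))

# Multiples of p spaced 30 * p apart are 8 * p wheel indices apart.
p = 11
indices = [num2ix(p * p + 30 * p * k) for k in range(4)]
strides = [b - a for a, b in zip(indices, indices[1:])]
print(indices, strides)
```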
Performance Data:
$ ./v51_segwheel.py 1e6
sieve size: 64 KB
pi(1e6) = 78498
0.00406 seconds, 1.5 cycles/N log log N)
$ ./v51_segwheel.py 1e7
sieve size: 512 KB
pi(1e7) = 664579
0.0323 seconds, 1.1 cycles/N log log N)
$ ./v51_segwheel.py 1e8
sieve size: 1024 KB
pi(1e8) = 5761455
0.288 seconds, 0.97 cycles/N log log N)
$ ./v51_segwheel.py 1e9
sieve size: 1024 KB
Sieving: [############################################################] 32/32
pi(1e9) = 50847534
2.79 seconds, 0.91 cycles/N log log N)
The cycles per N log log N shrink as the sieve size grows, probably because a larger share of the run time is spent in the optimized sieving code rather than in initialization and everything else. The sieve size is capped at 1 MB; that produces the fastest results for N in the billions, perhaps because it almost fits in the 0.5 MB L2 CPU cache. For the smaller sieve sizes, there should be only one segment. Above N = 1e8 the progress bar starts appearing - possible ADD issues here :-( .
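The sieve sizes and segment counts in the runs above follow from the halving loop in segmentedSieve. Here is a small sketch that repeats just that arithmetic without allocating a sieve; `segment_plan` is a hypothetical helper written for this illustration, while `SIEVE_SIZE` and the index table are taken from the posted code.

```python
SIEVE_SIZE = 2 ** 20  # bytes, same tuning value as in the post
WHEEL_INDICES = [0,0,0,0,1,1,2,2,2,2, 3,3,4,4,4,4,5,5,5,5, 5,5,6,6,7,7,7,7,7,7]


def num2ix(n):
    """Wheel index of n, as in m30_wheel.num2ix."""
    n -= 7
    return (n // 30 << 3) + WHEEL_INDICES[n % 30]


def segment_plan(limit, sieve_size=SIEVE_SIZE):
    """Return (segment size in bytes, number of segments) for a limit."""
    lim_ix = num2ix(limit)
    sieve_bits = sieve_size * 8
    # halve the segment while half of it would still cover the whole range
    while (sieve_bits >> 1) >= max(lim_ix, 1):
        sieve_bits >>= 1
    num_segments = (lim_ix + sieve_bits - 1) // sieve_bits  # round up
    return sieve_bits // 8, num_segments


for exp in range(6, 11):
    size, segments = segment_plan(10 ** exp)
    print(f"N = 1e{exp}: {size // 1024} KB per segment, {segments} segment(s)")
```

For 1e9 and 1e10 this gives 32 and 318 segments, matching the totals shown by the progress bar.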
N = 1e9 (one billion) is the performance sweet spot at present. Beyond that, you can see the cycles per N log log N starting to creep up:
$ ./v51_segwheel.py 1e10
sieve size: 1024 KB
Sieving: [############################################################] 318/318
pi(1e10) = 455052511
35.3 seconds, 1.1 cycles/N log log N)
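The cycles figure in these runs is just clock speed times elapsed time, divided by N log2(log2 N). As a sanity check, this sketch recomputes it from the timings reported above; `cycles_per_cull` is a name invented for the illustration, and the 1.6 GHz value is the `CLOCK_SPEED` constant from the script.

```python
from math import log

CLOCK_SPEED = 1.6  # GHz, as set in the script


def cycles_per_cull(n, elapsed, clock_ghz=CLOCK_SPEED):
    """CPU cycles spent per unit of N * log2(log2(N))."""
    culls = n * log(log(n, 2), 2)
    return clock_ghz * 1e9 * elapsed / culls


# elapsed seconds copied from the runs shown in this post
runs = {1e6: 0.00406, 1e7: 0.0323, 1e8: 0.288, 1e9: 2.79, 1e10: 35.3}
for n, elapsed in runs.items():
    print(f"N = {n:.0e}: {cycles_per_cull(n, elapsed):.2} cycles per N log log N")
```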
I've run the earlier version up to 1e12 (1 trillion). But that's no fun for someone with mild ADD. It takes a good part of a day. The progress bar starts to be very useful. I had to keep an eye on the laptop to stop it from hibernating. Once, when it did hibernate and I woke it up, my WSL Ubuntu bash terminal froze, but I was able to hit various keys to salvage the run.
The hot spots:
$ python3 -m cProfile -s 'tottime' ./v51_segwheel.py 1e9 | head -15
...
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    77125    1.664    0.000    1.736    0.000 v51_segwheel.py:112(cull_one_multiple)
      2/1    1.188    0.594    3.049    3.049 v51_segwheel.py:153(segmentedSieve)
       33    0.083    0.003    1.837    0.056 v51_segwheel.py:136(cull_segment)
    80560    0.075    0.000    0.075    0.000 v51_segwheel.py:64(update)
       32    0.012    0.000    0.012    0.000 {method 'count' of 'bitarray._bitarray' objects}
     3435    0.009    0.000    0.015    0.000 v51_segwheel.py:68(update_new_mult)
WHAT I'M LOOKING FOR
- Performance enhancements.
- I'm using a bitarray as the sieve. If you know of something that performs better as a sieve, please answer.
- Help here:
    # collect results from this segment
    if statsOnly:
        count += sieve[:top].count() # a lightweight way to get a result
    else:
        for i in range(top): # XXX not so lightweight
            if sieve[i]:
                x = i + lo_ix
                # ix2num(x) inlined below, performance is sensitive here
                p = 30 * (x >> 3) + wheel.wheel_primes[x & 7]
                outPrimes.append(p)
The statsOnly leg is great because bitarray is no doubt doing the work in optimized C. I think the else leg could be shrunk. It would be fantastic to change the else into a generator, i.e. yield the primes. I tried that, but then had problems getting it to return the count when the recursion unwound to the top level. It seemed to be stuck in generator mode and didn't want to be bi-modal.
- Algorithmic advice. I chose a mod 30 wheel vs. mod 210 because the former has 8 teeth allowing shifts and & ops to replace divide and mod. But I see that there are only a couple of places where the bit hacks are used in the critical paths, so eliminating multiples of 7 from the data structure/culling code may be a win.
- Ways to shrink, clarify, or further modularize the code.
- Help with the class stuff. This is my first voluntary OOP effort. I did dabble in JUnit back when I worked for {bigCo}. That gave me a bad taste for objects, but in retrospect, the badness was probably due to the JVM. Not a problem in Python.
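On the "stuck in generator mode" problem mentioned in the list above: in Python any function whose body contains `yield` is a generator function, so calling it always returns a generator object, even on a code path that never reaches the `yield`. One common way around that is to keep the generator single-purpose and put the count and list behaviour in thin wrappers. The sketch below shows only that shape; it uses a plain bytearray sieve as a stand-in, not the segmented wheel sieve, and the function names are made up.

```python
def iter_primes(limit):
    """Yield the primes <= limit (stand-in simple sieve, not the wheel)."""
    if limit < 2:
        return
    sieve = bytearray([1]) * (limit + 1)
    sieve[0:2] = b"\x00\x00"  # 0 and 1 are not prime
    for i in range(2, int(limit ** 0.5) + 1):
        if sieve[i]:
            # clear every multiple of i from i*i upward
            sieve[i * i::i] = bytes(len(range(i * i, limit + 1, i)))
    for n, is_prime in enumerate(sieve):
        if is_prime:
            yield n


def count_primes(limit):
    """Count primes without building a list."""
    return sum(1 for _ in iter_primes(limit))


def list_primes(limit):
    """Materialize the primes, e.g. for use as sieving primes."""
    return list(iter_primes(limit))


print(count_primes(10 ** 6))
print(list_primes(30))
```

In the segmented sieve, the statsOnly leg could stay a separate non-generator function built on `count()`, so the fast C path is not lost.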
EDIT
- Updated the code with a new version which adds the PrimeMultiple class in place of three separate arrays. No noticeable change in performance.
- Added the performance info and "what I want" sections.
- Minor wording tweaks to the original post