Skip to main content
1 of 4
Gareth Rees
  • 50.1k
  • 3
  • 130
  • 211

1. Encapsulate

Writing code at the top level of a module makes it hard to test the code and hard to measure its performance. It's best to encapsulate code in a function. Accordingly, I'd write:

def survivor(n):
    """Return the survivor of a circular firing squad of n people."""
    persons = list(range(1, n + 1))
    while len(persons) > 1:
        for index, _ in enumerate(persons):
            del persons[(index + 1) % len(persons)]
    return persons[0]

(The variable person is not used in the body of the for loop; it's convential to write _ for such variables, and that's what I've done here.)

2. The meaning of efficiency

Now, is this code "efficient" as the question asks? Normally in computing we use "efficient" to refer to algorithmic efficiency: that is, to the rate at which the resources used by the program grow as a function of the input, usually expressed in big-O notation.

According to this view of efficiency, the programming language does not matter: efficiency is a property of the algorithm, not of the language it's implemented in. But you've written "Python is not efficient", which suggests that you have a different understanding of this word. Maybe you could update your post to explain what you mean by it? For the remainder of this answer I'm going to assume the question intends the usual meaning.

3. It's accidentally quadratic

What's the runtime of the survivor function, expressed as a function of \$ n \$? Well, looking at the time complexity page on the Python Wiki, we can see that the del operation on a list takes \$ O(n) \$ where \$ n \$ is the length of the list, and this is executed once for each person who is killed, resulting in an overall runtime of \$ O(n^2) \$.

It's possible to check this experimentally:

>>> t = 1
>>> for i in range(8, 17):
...     t, u = timeit(lambda:survivor(2**i), number=1), t
...     print('{:6d} {:.6f} {:.2f}'.format(2**i, t, t / u))
... 
   256 0.000138 0.00
   512 0.000318 2.31
  1024 0.000560 1.76
  2048 0.001363 2.43
  4096 0.006631 4.87
  8192 0.030330 4.57
 16384 0.132857 4.38
 32768 0.534205 4.02
 65536 2.134860 4.00

You can see that for each doubling of n, the runtime increases by roughly four times, which is what we expect for an \$ O(n^2) \$ algorithm.

4. Making it linear

How can we speed this up? Well, we could avoid the expensive del operation by making a list of the survivors, instead of deleting the deceased. Consider a single trip around the circular firing squad. If there are an even number of people remaining, then the people with indexes 0, 2, 4, and so on, survive. But if there are an odd number of people remaining, then the last survivor shoots the person with index 0, so the survivors are the people with indexes 2, 4, and so on. Putting this into code form:

def survivor2(n):
    """Return the survivor of a circular firing squad of n people."""
    persons = list(range(1, n + 1))
    while len(persons) > 1:
        if len(persons) % 2 == 0:
            persons = persons[::2]
        else:
            persons = persons[2::2]
    return persons[0]

(You could shorten this, if you liked, using an expression like persons[(len(persons) % 2) * 2::2], but I don't think the small reduction in code length is worth the loss of clarity.)

Let's check that this is correct, by comparing the results with the original implementation:

>>> all(survivor(i) == survivor2(i) for i in range(1, 1000))
True

Notice how useful it is for testing that we have the code organized into functions.

Now, what's the runtime of survivor2? Again, looking at the time complexity page on the Python Wiki, we can see that the "get slice" operation takes time \$ O(k) \$ where \$ k \$ is the number of items in the slice. In this case each slice is half the length of persons, so the runtime is $$ O\big({n \over 2}\big) + O\big({n \over 4}\big) + O\big({n \over 8}\big) + \dots $$ which is \$ O(n) \$. Again, we can check that experimentally:

>>> t = 1
>>> for i in range(8, 25):
...     t, u = timeit(lambda:survivor2(2**i), number=1), t
...     print('{:8d} {:8.6f} {:.2f}'.format(2**i, t, t / u))
... 
     256 0.000034 0.00
     512 0.000048 1.40
    1024 0.000087 1.79
    2048 0.000142 1.63
    4096 0.000300 2.12
    8192 0.000573 1.91
   16384 0.001227 2.14
   32768 0.002628 2.14
   65536 0.006003 2.28
  131072 0.017954 2.99
  262144 0.043873 2.44
  524288 0.094669 2.16
 1048576 0.180889 1.91
 2097152 0.364302 2.01
 4194304 0.743028 2.04
 8388608 1.497255 2.02
16777216 3.094121 2.07

Now, for each doubling of n, the runtime increases by roughly two times, which is what we expect for an \$ O(n) \$ algorithm.

5. Making it (poly)logarithmic

Can we do even better than this? Let's look at who the survivors actually are after each trip around the firing squad:

from pprint import pprint

def survivors(n):
    """Print survivors after each round of circular firing squad with n people."""
    persons = list(range(1, n + 1))
    while len(persons) > 1:
        if len(persons) % 2 == 0:
            persons = persons[::2]
        else:
            persons = persons[2::2]
        pprint(persons, compact=True)

>>> survivors(100)
[1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41,
 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, 79, 81,
 83, 85, 87, 89, 91, 93, 95, 97, 99]
[1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61, 65, 69, 73, 77,
 81, 85, 89, 93, 97]
[9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97]
[9, 25, 41, 57, 73, 89]
[9, 41, 73]
[73]

After each round, the gap between the survivors doubles. So after one round, the survivors are every 2nd person; after two rounds, every 4th person, after three rounds, every 8th person, and so on. When there are an even number of people, the first person remains the first survivor in the next round. But when there are an odd number of people, the third person becomes the first survivor in the next round.

So to solve this problem, we don't need to remember all the survivors, we only need to remember the first survivor, the gap between survivors, and the total number of survivors. Like this:

def survivor3(n):
    """Return the survivor of a circular firing squad of n people."""
    first = 1
    gap = 1
    while n > 1:
        if n % 2: first += 2 * gap
        gap *= 2
        n //= 2
    return first

Once again, we should check that this is correct:

>>> all(survivor(i) == survivor3(i) for i in range(1, 1000))
True

How fast is this? Here, each time round the loop we just have some arithmetic operations on numbers of size no more than \$ n \$, taking \$ O(\log n) \$, and each time around the loop we divide the number of survivors by two, so the loop executes \$ O(\log n) \$ times, for an overall runtime of \$ O((\log n)^2) \$.

>>> t = 0
>>> for i in range(8, 60, 4):
...     t, u = timeit(lambda:survivor3(2**i), number=1), t
...     print('{:20d} {:8.6f} {:.6f}'.format(2**i, t, t - u))
... 
                 256 0.000014 0.000014
                4096 0.000012 -0.000001
               65536 0.000013 0.000001
             1048576 0.000015 0.000002
            16777216 0.000016 0.000002
           268435456 0.000018 0.000002
          4294967296 0.000022 0.000003
         68719476736 0.000030 0.000008
       1099511627776 0.000026 -0.000003
      17592186044416 0.000027 0.000001
     281474976710656 0.000031 0.000003
    4503599627370496 0.000032 0.000001
   72057594037927936 0.000035 0.000003

You can see that for each doubling of n, the runtime increases by a roughly constant amount, which is what we expect for a polylogarithmic algorithm.

Is this as efficient as possible? Well, we could probably squeeze out one of the factors of \$ O(\log n) \$ by being cleverer about how we manipulate the values in the loop (each operation on first only affects one bit of the result), making the algorithm \$ O(\log n) \$. But for modest sizes of firing squads, this makes no practical difference, so I'll leave it there.

Gareth Rees
  • 50.1k
  • 3
  • 130
  • 211