
Kattis problem Pivot: Incrementally improving the performance of a python script, until nothing makes sense anymore

I have been doing a lot of competitive programming problems on Kattis lately. It is my favourite online judge, and I have been enjoying my climb up the Irish scoreboard in an attempt to regain my former glory of first place.

In particular, today I was tackling the problem Pivot. It is quite a small, simple, self-contained problem, but in trying to optimise my answer, I learned some interesting things.

Here is the gist of the problem statement:

An O(n) Partition algorithm partitions an array A around a pivot element (pivot is a member of A) into three parts: a left sub-array that contains elements that are ≤ pivot, the pivot itself, and a right sub-array that contains elements that are > pivot.

Now the actual job in this problem is this: Starting from an array A that has n distinct integers, we partition A using one of the member of A as pivot to produce a transformed array A’. Given this transformed array A’, your job is to count how many integers that could have been the chosen pivot.

Simply put, we are given the output of a partition, and have to count how many values in the list could have been used as the pivot.
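To make the setup concrete, here is a minimal sketch of one way a partition could produce A'. The `partition` helper is hypothetical: the problem doesn't say which O(n) algorithm is used, and this one builds new lists rather than working in place. It only illustrates the "≤ pivot, pivot, > pivot" layout.

```python
def partition(a, pivot):
    """Hypothetical partition: values below the pivot, the pivot, then values above it.

    The values are distinct, so "x < pivot" is the same as "x <= pivot, excluding the pivot".
    """
    left = [x for x in a if x < pivot]
    right = [x for x in a if x > pivot]
    return left + [pivot] + right


print(partition([5, 1, 4, 2, 3], 3))  # [1, 2, 3, 5, 4]
```

In that output, 1, 2 and 3 are each ≥ everything before them and < everything after them, so any of the three could have been the pivot, and the answer for that array would be 3.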

We are told that the number of elements, n, is in the range 3 ≤ n ≤ 100000. We also know that all input numbers are 32-bit signed integers (but this doesn't make much difference in Python, whose integers are arbitrary-precision and can't overflow).

Now, let's go down the rabbit hole of solving this problem. And then let's go a little further.

Naïve Solution

I'm not saying it's your fault / Although you could have done more

When we break it down, we see that we are trying to find each number in the input list (which I will call D) such that D[i] is greater than or equal to all D[j] where j < i and less than all D[k] where k > i. This idea can be implemented as follows.

n = int(input())

d = list(map(int, input().split()))

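ans = 0
for i in range(n):
    ok = True
    # d[i] must be >= every element before it...
    for j in range(i):
        if d[i] < d[j]:
            ok = False
            break
    # ...and < every element after it (values are distinct, so ">" catches every failure)
    for k in range(i+1, n):
        if d[i] > d[k]:
            ok = False
            break
    if ok:
        ans += 1
print(ans)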

This solution has time complexity O(n^2), and it exceeds the 1 second time limit on the 4th test case on Kattis. It is simply not fast enough.

This is because, for every one of the (up to) one hundred thousand input numbers, we are iterating over (in the worst case) every other one of the numbers. That's a lot of comparisons.
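To put a number on "a lot", here is a sketch of the same naive check instrumented with a comparison counter (`count_comparisons` is just an illustrative name). A sorted array is the worst case: no comparison ever fails, so nothing breaks early and every element is checked against all n - 1 others.

```python
def count_comparisons(d):
    """Run the naive pivot check, returning (answer, number of comparisons made)."""
    n = len(d)
    ans = 0
    comparisons = 0
    for i in range(n):
        ok = True
        for j in range(i):
            comparisons += 1
            if d[i] < d[j]:
                ok = False
                break
        for k in range(i + 1, n):
            comparisons += 1
            if d[i] > d[k]:
                ok = False
                break
        if ok:
            ans += 1
    return ans, comparisons


# Sorted input: every element is a valid pivot and no loop exits early.
print(count_comparisons(list(range(1000))))  # (1000, 999000)
```

That is n(n - 1) comparisons, which for n = 100000 comes to roughly 10^10.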

There are some minor optimisations you could make, like skipping the loop over k if ok is already False, but I doubt these could squeeze it under the time limit. Still, they're probably worth trying if you have no other ideas, you're in the squeaky bum time of a programming competition, and you're desperate to get a few extra points.

O(n) fleek

I knew that O(n^2) wouldn't suffice when I read the question: around 10^10 comparisons is too many to do in one second. However, I realised straight away that this problem could be solved in O(n) time.

You don't need to compare each number with every number before it to see if it is greater than or equal to each of them. You only need to compare the number to the maximum number before it. Likewise, you only need to compare the number to the minimum number after it.

You can build a list, which I call left below, of "the largest number to the left of this element", in O(n) time.

You can build an equivalent list right of "the smallest number to the right of this element" in similar fashion.
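As a sketch of those two running lists, Python's `itertools.accumulate` can build running maxima and minima directly. Note that this version is inclusive (each entry includes the element itself), unlike the strictly-left and strictly-right lists in my solution; `prefix_max` and `suffix_min` are just illustrative names.

```python
from itertools import accumulate


def prefix_max(d):
    """prefix_max(d)[i] is the largest of d[0..i], inclusive."""
    return list(accumulate(d, max))


def suffix_min(d):
    """suffix_min(d)[i] is the smallest of d[i..n-1], inclusive."""
    # Accumulate from the right-hand end, then flip back into the original order.
    return list(accumulate(reversed(d), min))[::-1]


d = [1, 2, 3, 5, 4]
print(prefix_max(d))  # [1, 2, 3, 5, 5]
print(suffix_min(d))  # [1, 2, 3, 4, 4]
```

Both are single passes over the list, so building them is O(n).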

The code is probably easier to understand than my rambling about it. Here is the first solution I submitted.

n = int(input())

d = list(map(int, input().split()))

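# left[i] is the largest value strictly to the left of index i.
# Index 0 has nothing to its left, so it gets 0 (this sentinel assumes non-negative input).
left = [0 for _ in range(n)]
left[0] = 0
for i in range(1, n):
    left[i] = max(left[i-1], d[i-1])

# right[i] is the smallest value strictly to the right of index i.
# 2**32 - 1 is larger than any 32-bit signed integer, so it stands in for "nothing to the right".
right = [2**32 - 1 for _ in range(n)]
for i in range(n-2, -1, -1):
    right[i] = min(right[i+1], d[i+1])

# d[i] is a possible pivot if it is >= everything before it and < everything after it.
ans = 0
for i in range(n):
    if d[i] >= left[i] and d[i] < right[i]:
        ans += 1
print(ans)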

Kattis told me this ran in 0.16 seconds, well below the time limit, and suitable to get full points for this problem. Job done!
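Before trusting an O(n) rewrite, a cheap sanity check is to compare it against the brute force on lots of small random inputs. This is a sketch of that idea rather than anything I submitted: `naive` and `fast` are illustrative names, and `fast` uses infinite sentinels instead of 0 and 2**32 - 1 so that negative values are handled too.

```python
import random


def naive(d):
    """O(n^2): d[i] must be >= everything before it and < everything after it."""
    n = len(d)
    return sum(
        all(d[j] <= d[i] for j in range(i)) and all(d[i] < d[k] for k in range(i + 1, n))
        for i in range(n)
    )


def fast(d):
    """O(n): compare against the max strictly to the left and the min strictly to the right."""
    n = len(d)
    right = [float("inf")] * n  # right[i] = min(d[i+1:]), or inf if nothing is to the right
    for i in range(n - 2, -1, -1):
        right[i] = min(right[i + 1], d[i + 1])
    ans = 0
    maxn = float("-inf")  # largest value seen so far, strictly left of i
    for i in range(n):
        if d[i] >= maxn and d[i] < right[i]:
            ans += 1
        maxn = max(maxn, d[i])
    return ans


random.seed(0)
for _ in range(1000):
    n = random.randint(3, 12)
    d = random.sample(range(-50, 50), n)  # distinct integers, as the problem promises
    assert naive(d) == fast(d), d
print("all random tests agree")
```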

Kattis has a nice feature that shows you a scoreboard of the ten fastest solutions to each problem, broken down by language. I often go and check out this table after I solve a problem, to see how my solution compares.

On peeking at the scoreboard for this problem, I was shooketh. Not only was the fastest Python 3 submission a mere 0.06 seconds, but it was submitted by my classmate/arch-nemesis/competitive programming teammate Cian Ruane. The horror! I couldn't live with his submission being a full tenth of a second faster; imagine how big his head would be if he saw this! I must do better.

Speedup #1

Jealousy's what the cheddar brings, for the cheddar it's anything goes

Quickly (I submitted the new solution just 1 minute 18 seconds later), I realised that I was doing twice the work I needed to. Instead of building the left list, I could just keep a running tally of the largest number seen so far while I'm iterating the list at the end to compute the answer. This will make the solution use around half as much space, will save an entire iteration of the array, and will surely improve my speed.

Maybe this is how Cian got a better result? Following is what I submitted.

n = int(input())

d = list(map(int, input().split()))

right = [2**32 - 1 for _ in range(n)]  
for i in range(n-2, -1, -1):  
    right[i] = min(right[i+1], d[i+1])

ans = 0  
maxn = 0  
for i in range(n):  
    if d[i] >= maxn and d[i] < right[i]:
        ans += 1
    maxn = max(maxn, d[i])
print(ans)  

It ran in 0.14s. Not even fast enough to appear at 10th place on the scoreboard, never mind first.

From previous experience, I've seen that even the most trivial Python solution takes 0.02 seconds. I assume this is the cost of starting the Python interpreter, parsing the program, etc. Once a solution runs in under 0.05s or so, there is likely little room left to speed it up. However, we're not at that limit yet. Besides, I can see someone else did it in 0.06s, and the rest of the scoreboard is at 0.07s - 0.09s, so it must be doable!
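If you want a rough idea of that floor on your own machine, one option is to time an interpreter that runs an empty program. This is only a sketch under my own assumptions: it measures wall-clock time locally, which won't match how Kattis measures, and the numbers will vary with hardware and Python version.

```python
import subprocess
import sys
import time


def startup_time(runs=5):
    """Best-of-`runs` wall-clock time to start Python and run an empty program."""
    best = float("inf")
    for _ in range(runs):
        start = time.perf_counter()
        subprocess.run([sys.executable, "-c", "pass"], check=True)
        best = min(best, time.perf_counter() - start)
    return best


print(f"{startup_time():.3f}s")
```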

4 minutes older, 0.04s faster

2 minutes and 45 seconds later, I submitted the following:

n = int(input())

d = list(map(int, input().split()))

right = [2**32 - 1 for _ in range(n)]  
for i in range(n-2, -1, -1):  
    right[i] = right[i+1]
    if d[i+1] < right[i]:
        right[i] = d[i+1]

ans = 0  
maxn = 0  
for i in range(n):  
    if d[i] >= maxn and d[i] < right[i]:
        ans += 1
    if d[i] > maxn:
        maxn = d[i]
print(ans)  

I have seen several people avoid the builtin Python max and min, and looking at this problem, I don't see any other obvious potential speedups. It makes sense to me that these could be slower than simple if statements: not only is there the overhead of a function call every time, but the functions also support getting the maximum or minimum of an iterable, so there must be some logic in there to work out which case applies, which we can avoid.
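A quick way to test that hunch locally is a `timeit` micro-benchmark of the two styles. This is a sketch: `with_max` and `with_if` are illustrative names, starting from 0 assumes non-negative data, and the absolute (and even relative) timings will depend on your machine and Python version, so I haven't written any expected numbers here.

```python
import timeit
from functools import partial


def with_max(d):
    maxn = 0  # assumes non-negative data
    for x in d:
        maxn = max(maxn, x)  # a builtin function call on every iteration
    return maxn


def with_if(d):
    maxn = 0  # assumes non-negative data
    for x in d:
        if x > maxn:  # a plain comparison, no call
            maxn = x
    return maxn


data = list(range(100_000))
assert with_max(data) == with_if(data)
for fn in (with_max, with_if):
    # Best of 5 repeats of 10 calls each, to smooth out noise.
    best = min(timeit.repeat(partial(fn, data), number=10, repeat=5))
    print(f"{fn.__name__}: {best:.4f}s")
```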

This solution ran in 0.12s.

Cian's solution still took half that amount of time. My curiosity was piqued; how did he do it? I decided to read his solution and find out - it could teach me a lot.

Peeking behind the curtain

I went to the GitHub repo of his Kattis solutions and looked at his solution:

def main():  
    N = int(input())
    a = [int(x) for x in input().split()]
    mx = -1
    max_left = [None] * N
    for i in range(N):
        if a[i] > mx:
            mx = a[i]
        max_left[i] = mx
    mn = mx
    possible = 0
    for i in range(N - 1, -1, -1):
        if a[i] < mn:
            mn = a[i]
        if mn == max_left[i]:
            possible += 1
    print(possible)

if __name__ == '__main__':  
    main()

His solution is... more or less identical. He builds his lists in the opposite order to me, but that shouldn't make a difference?

He has one fewer comparison in his final loop - maybe that's it? He builds his list of input values in a list comprehension instead of a map - that could be slightly more optimal? Maybe I'm missing something in his iteration orders, and there is some secret smart speedup going on?
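The single check works because both of his running values are inclusive: max_left[i] is the max of a[0..i] and mn is the min of a[i..N-1]. Since a[i] is in both ranges, max_left[i] ≥ a[i] ≥ mn, so equality forces a[i] to be both the largest value up to i and the smallest from i onwards, which, with distinct values, is exactly the pivot condition. Here is a sketch that checks this equivalence on random inputs (`two_checks` and `one_check` are illustrative names, not his code).

```python
import random
from itertools import accumulate


def two_checks(d):
    """My condition: d[i] >= everything strictly left of i and < everything strictly right."""
    n = len(d)
    count = 0
    for i in range(n):
        left_ok = all(d[j] <= d[i] for j in range(i))
        right_ok = all(d[i] < d[k] for k in range(i + 1, n))
        count += left_ok and right_ok
    return count


def one_check(d):
    """Cian's condition: inclusive prefix max equals inclusive suffix min."""
    prefix_max = list(accumulate(d, max))
    suffix_min = list(accumulate(reversed(d), min))[::-1]
    return sum(p == s for p, s in zip(prefix_max, suffix_min))


random.seed(1)
for _ in range(1000):
    d = random.sample(range(100), random.randint(3, 10))  # distinct values
    assert two_checks(d) == one_check(d), d
```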

I stole some of his ideas, and submitted the following.

n = int(input())

d = [int(x) for x in input().split()]

left = [None] * n  
maxn = 0  
for i in range(n):  
    if d[i] > maxn:
        maxn = d[i]
    left[i] = maxn

ans = 0  
minn = maxn  
for i in range(n-1, -1, -1):  
    if d[i] < minn:
        minn = d[i]
    if minn == left[i]:
        ans += 1
print(ans)  

This ran in 0.09s.

WTF

Defeated, I fired him a message.

Noah: @Cian Ruane I have pretty much the exact same solution as you for problem pivot, but yours is 0.03s faster, wtf

Cian: That's incredible

...

Noah: do you get a magic speedup from the if __name__ == '__main__': line

Noah: did you get lucky with the server load or something

...

Cian: Try putting it into a function

Noah: surely that'd be slower

Cian: Maybe it's doing some sort of jit stuff or soemthing

Sure, why not try? I submitted the following.

def main():  
    n = int(input())

    d = [int(x) for x in input().split()]

    left = [None] * n
    maxn = 0
    for i in range(n):
        if d[i] > maxn:
            maxn = d[i]
        left[i] = maxn

    ans = 0
    minn = maxn
    for i in range(n-1, -1, -1):
        if d[i] < minn:
            minn = d[i]
        if minn == left[i]:
            ans += 1
    print(ans)

if __name__ == '__main__':  
    main()

This solution ran in 0.06s.

WTF 2: Electric Boogaloo

Why does putting the code in a function speed it up? Surely the overhead of jumping to a function should make the code slower, not faster? Is Cian right that there's some JIT stuff going on?

No - but Stack Overflow has the answer:

Scharron: Just an intuition, not sure if it's true: I would guess it's because of scopes. In the function case, a new scope is created (i.e. kind of a hash with variable names bound to their value). Without a function, variables are in the global scope, when you can find lot of stuff, hence slowing down the loop

katrielalex: @Scharron you're half correct. It is about scopes, but the reason it's faster in locals is that local scopes are actually implemented as arrays instead of dictionaries (since their size is known at compile-time).

katrielalex then goes into more interesting details about variable scopes in a full answer.

ecatmur also looks at the bytecode and shows how it comes out in the interpreter.
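You can see the same thing yourself with the `dis` module. In this sketch, the loop inside a function uses the *_FAST instructions (which index into the locals array), while the same loop compiled as module-level code uses LOAD_NAME/STORE_NAME (dictionary lookups). The exact opcode names vary between CPython versions, since newer ones fuse some of them into combined instructions, so I filter by substring rather than listing expected output.

```python
import dis


def in_function():
    total = 0
    for i in range(3):
        total += i
    return total


# The same loop, compiled as if it were the top level of a script.
module_code = compile("total = 0\nfor i in range(3):\n    total += i\n", "<module>", "exec")


def opnames(code):
    """Set of bytecode instruction names used by a function or code object."""
    return {ins.opname for ins in dis.get_instructions(code)}


print(sorted(op for op in opnames(in_function) if "FAST" in op or "NAME" in op))
print(sorted(op for op in opnames(module_code) if "FAST" in op or "NAME" in op))
```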

Super interesting.

Postmortem

Having seen all of this, what have I learned?

  1. There is no big downside, but a potential upside, to putting my scripts inside a function, instead of in the global scope.
    • It is faster, as I saw here. I'd guess the speedup is less noticeable in scripts that access variables less often.
    • Some of the answers here mention benefits of the code not being run on module import.
    • It creates cleaner code (in my opinion). I was just too lazy to do it for throwaway competitive programming solutions before.
  2. I should consider using an if instead of max and min if I am only comparing 2 values, and I am worried about performance.
    • The code is not as nice, but it could save valuable time if it is run often.
  3. I should learn off this page about Python performance tips that katrielalex linked. There are some great details about the innards of Python there.
  4. I should keep on Cian's good side, so that he continues to be on my team and not on a competing team. That said, I have done better than him in some previous competitions, and I will never let him forget those.
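On the import point from lesson 1: the `if __name__ == '__main__':` guard means main() only runs when the file is executed as a script. This sketch demonstrates it with `runpy.run_path`, which runs a file under whatever `__name__` you give it and returns the module's globals. The solution file and its `RAN` flag are hypothetical, written to a temporary directory just for the demo.

```python
import os
import runpy
import tempfile

# A hypothetical solution file following the main() pattern from this post.
SOURCE = """
RAN = False

def main():
    global RAN
    RAN = True

if __name__ == '__main__':
    main()
"""


def ran_main(run_name):
    """Execute SOURCE with __name__ set to run_name and report whether main() ran."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "solution.py")
        with open(path, "w") as f:
            f.write(SOURCE)
        return runpy.run_path(path, run_name=run_name)["RAN"]


print(ran_main("__main__"))  # True: run as a script
print(ran_main("solution"))  # False: behaves like an import
```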