Memoization Decorator in Python

In programming, memoization is an optimization technique that speeds up a program by caching the result of a function call for a given input.

Once a result is cached, a later call with the same input skips the computation and returns the stored value from the cache instead, which saves execution time.
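As a minimal sketch of that idea before any decorator is involved, the following hand-rolled cache stores results in a dict keyed by the input. The names slow_square, cached_square and calls are hypothetical and used only for illustration; slow_square stands in for an expensive computation.

```python
# Hand-rolled memoization: results are stored in a dict keyed by the input.
cache = {}
calls = 0  # counts how often the expensive function actually runs


def slow_square(x):
    """Hypothetical stand-in for an expensive computation."""
    global calls
    calls += 1
    return x * x


def cached_square(x):
    # Compute only on a cache miss; afterwards answer from the dict.
    if x not in cache:
        cache[x] = slow_square(x)
    return cache[x]


print(cached_square(4), cached_square(4), cached_square(5))  # 16 16 25
print('underlying calls:', calls)                            # underlying calls: 2
```

The second request for 4 is answered from the dict, so slow_square runs only twice for three calls.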

Understanding the Need for Memoization

To understand memoization better, let's consider the following recursive function for calculating the nth term of the Fibonacci series:


# Fibonacci recursion (not memoized yet)

def fib(n):
    print('calculating term - {0}'.format(n))
    if n < 3:
        return 1
    else:
        return fib(n-1) + fib(n-2)

print('6th term is: ',fib(6))

The output of the above program is:

calculating term - 6
calculating term - 5
calculating term - 4
calculating term - 3
calculating term - 2
calculating term - 1
calculating term - 2
calculating term - 3
calculating term - 2
calculating term - 1
calculating term - 4
calculating term - 3
calculating term - 2
calculating term - 1
calculating term - 2
6th term is:  8

Looking at the output, term-4 is calculated 2 times, term-3 is calculated 3 times, and term-2 is calculated 5 times. The function is called again and again just to recompute values it has already produced, and this repeated work grows rapidly as n increases. It can be avoided by caching, or memoizing, the terms that have already been calculated.
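The counts above can be confirmed by instrumenting the function. The sketch below is a hypothetical variant of fib that tallies each call in a collections.Counter instead of printing; count_calls is an illustrative helper, not part of the original program. It also shows how the total number of calls grows with n.

```python
from collections import Counter


def count_calls(n):
    """Return fib(n) and a Counter of how often each term was requested."""
    calls = Counter()

    def fib(k):
        calls[k] += 1  # record every call, including repeats
        if k < 3:
            return 1
        return fib(k - 1) + fib(k - 2)

    return fib(n), calls


value, calls = count_calls(6)
print(value)                  # 8
print(sorted(calls.items()))  # [(1, 3), (2, 5), (3, 3), (4, 2), (5, 1), (6, 1)]

# Total number of calls for a few values of n.
for n in (6, 10, 20):
    print(n, sum(count_calls(n)[1].values()))
# prints:
# 6 15
# 10 109
# 20 13529
```

For this version of fib the total is exactly 2 * fib(n) - 1 calls (by induction on the recurrence), so the repeated work grows as fast as the Fibonacci numbers themselves.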

Now we will implement memoization using a Python decorator and then compare the results.

Python Source Code: Memoization


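import functools

# Memoizer decorator function
def memoizer(fn):
    # Cache mapping each input n to fn(n); it starts empty and persists between calls
    cache = {}

    # Closure: inner() reads and updates cache from the enclosing scope
    @functools.wraps(fn)  # keep fn's name and docstring on the wrapper
    def inner(n):
        if n not in cache:
            cache[n] = fn(n)
        return cache[n]

    # Returning closure
    return inner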

# Decorator in action
@memoizer
def fib(n):
    print('calculating - {0}'.format(n))
    if n < 3:
        return 1
    else:
        return fib(n-1) + fib(n-2)

    
print('--------------------------')
print('5th term: ', fib(5))
print('--------------------------')
print('10th term: ', fib(10))
print('--------------------------')
print('12th term: ', fib(12))
print('--------------------------')
print('8th term: ', fib(8))

Output

--------------------------
calculating - 5
calculating - 4
calculating - 3
calculating - 2
calculating - 1
5th term:  5
--------------------------
calculating - 10
calculating - 9
calculating - 8
calculating - 7
calculating - 6
10th term:  55
--------------------------
calculating - 12
calculating - 11
12th term:  144
--------------------------
8th term:  21

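In the above program, no term is calculated more than once. While calculating term-5, each of terms 5 down to 1 is computed exactly once. While calculating term-10, only terms 10, 9, 8, 7 and 6 are computed, because terms 1 to 5 were already cached during the call for term-5. Likewise, term-12 needs only terms 12 and 11, and term-8 is returned straight from the cache without any calculation. This works across separate calls because cache lives in the closure created once, when @memoizer decorated fib, and because the recursive calls fib(n-1) and fib(n-2) inside the function body go through the memoized inner function.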
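One way to check this is to count how often each term's body actually runs. The sketch below restates the memoizer (without functools.wraps) so it runs on its own, and replaces the print statements with a hypothetical calls counter. For contrast, it also wraps a function without rebinding its name: then only the outermost call is cached, because the recursive calls inside plain_fib still refer to the unwrapped function. The names plain_fib, plain_calls and wrapped are illustrative only.

```python
from collections import Counter


def memoizer(fn):
    """Dict-based memoizer for a one-argument function."""
    cache = {}

    def inner(n):
        if n not in cache:
            cache[n] = fn(n)
        return cache[n]

    return inner


calls = Counter()  # how many times each term's body actually runs


@memoizer
def fib(n):
    calls[n] += 1
    if n < 3:
        return 1
    # 'fib' here is the memoized wrapper, so sub-terms come from the cache.
    return fib(n - 1) + fib(n - 2)


for n in (5, 10, 12, 8):
    fib(n)
print(max(calls.values()))  # 1: no term was computed twice
print(len(calls))           # 12: terms 1 to 12, each computed once

# Contrast: wrapping without rebinding the name caches only the outer call.
plain_calls = Counter()


def plain_fib(n):
    plain_calls[n] += 1
    if n < 3:
        return 1
    return plain_fib(n - 1) + plain_fib(n - 2)  # calls the unwrapped function


wrapped = memoizer(plain_fib)
wrapped(6)
wrapped(6)                        # second call is a cache hit
print(sum(plain_calls.values()))  # 15: the first call still recursed naively
```

This is why the @memoizer line matters: it rebinds the name fib to the wrapper, so every recursive call benefits from the cache, not just the top-level one.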