Automatic differentiation is the foundation on which deep learning frameworks are built. Deep learning models are typically trained using gradient-based techniques, and autodiff makes it easy to get gradients, even from enormous, complex models. ‘Reverse-mode autodiff’ is the autodiff method used by most deep learning frameworks, due to its efficiency and accuracy.

Let’s:

  • Look at how reverse-mode autodiff works.
  • Create a minimal autodiff framework in Python.

The small autodiff framework will deal with scalars. However, we will look at a method of vectorising it with NumPy. We will also look at how to compute Nth order derivatives.

Note on terminology: from now on ‘autodiff’ will refer to ‘reverse-mode autodiff’. ‘Gradient’ is used loosely, but in this context generally means ‘first order partial derivative’.

How does autodiff work?

Let’s start with an example.

a = 4
b = 3
c = a + b  # = 4 + 3 = 7
d = a * c  # = 4 * 7 = 28

Q1: What is the gradient of $d$ with respect to $a$, i.e. $\frac{\partial{d}}{\partial{a}}$? (Go ahead and try this!)

Solving the ‘traditional’ way:

There are many ways to solve Q1, but let’s use the product rule, i.e. if $ y = x_1x_2 $ then $y' = x_1'x_2 + x_1x_2'$.

\[d = a * c\] \[\frac{\partial{d}}{\partial{a}} = \frac{\partial{a}}{\partial{a}} * c + a * \frac{\partial{c}}{\partial{a}}\] \[\frac{\partial{d}}{\partial{a}} = c + a * \frac{\partial{c}}{\partial{a}}\] \[\frac{\partial{d}}{\partial{a}} = (a + b) + a * \frac{\partial{(a + b)}}{\partial{a}}\] \[\frac{\partial{d}}{\partial{a}} = a + b + a * (\frac{\partial{a}}{\partial{a}} + \frac{\partial{b}}{\partial{a}})\] \[\frac{\partial{d}}{\partial{a}} = a + b + a * (1 + 0)\] \[\frac{\partial{d}}{\partial{a}} = a + b + a\] \[\frac{\partial{d}}{\partial{a}} = 2a + b\] \[\frac{\partial{d}}{\partial{a}} = 2*4 + 3 = 11\]

Phew… and if you wanted to know $\frac{\partial{d}}{\partial{b}}$ you’d have to carry out the process again.
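For instance, here is the same product-rule process repeated for $b$ (a worked example using only the definitions above):

\[\frac{\partial{d}}{\partial{b}} = \frac{\partial{a}}{\partial{b}} * c + a * \frac{\partial{c}}{\partial{b}}\] \[\frac{\partial{d}}{\partial{b}} = 0 * c + a * (\frac{\partial{a}}{\partial{b}} + \frac{\partial{b}}{\partial{b}})\] \[\frac{\partial{d}}{\partial{b}} = a * (0 + 1)\] \[\frac{\partial{d}}{\partial{b}} = a = 4\]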

Solving the autodiff way

We’ll now look at the autodiff way to solve Q1. Here is a figure:

The system as a graph.

On the left we see the system represented as a graph. Each variable is a node; e.g. $d$ is the topmost node, and $a$ and $b$ are leaf nodes at the bottom.

On the right we see the system from autodiff’s point of view. Let’s call the values on the graph edges local derivatives. By using local derivatives and simple rules, we will be able to compute the derivatives that we want.

Here is the answer to Q1, calculated the autodiff way. Can you see how it relates to the figure?

\[\frac{\partial{d}}{\partial{a}} = \frac{\partial{\bar{d}}}{\partial{a}} + \frac{\partial{\bar{d}}}{\partial{c}} * \frac{\partial{\bar{c}}}{\partial{a}}\] \[\frac{\partial{d}}{\partial{a}} = c + a * 1\] \[\frac{\partial{d}}{\partial{a}} = a + b + a\] \[\frac{\partial{d}}{\partial{a}} = 2a + b\] \[\frac{\partial{d}}{\partial{a}} = 11\]

We get this answer from the graph by finding the paths from $d$ to $a$ (not going against the dotted arrows), and then applying the following rules:

  • Multiply the edges of a path.
  • Add together the different paths.

The first path is straight from $d$ to $a$, which gives us the $\frac{\partial{\bar{d}}}{\partial{a}}$ term. The second path is from $d$ to $c$ to $a$, which gives us the term $\frac{\partial{\bar{d}}}{\partial{c}} * \frac{\partial{\bar{c}}}{\partial{a}}$.
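The two rules can be sketched directly in code. This is only an illustration, assuming the graph is stored as a plain dictionary of edges; the names `edges` and `path_sum` are made up for this sketch and are not part of the framework built later.

```python
# Hypothetical representation: node -> list of (child, local derivative).
a, b = 4, 3
c = a + b
edges = {
    "d": [("a", c), ("c", a)],  # d = a * c
    "c": [("a", 1), ("b", 1)],  # c = a + b
}

def path_sum(node, target):
    """Sum, over every path from `node` to `target`, the product of its edges."""
    if node == target:
        return 1  # empty path: a node differentiated w.r.t. itself
    total = 0
    for child, local_derivative in edges.get(node, []):
        # "Multiply the edges of a path", then "add together the different paths".
        total += local_derivative * path_sum(child, target)
    return total

print(path_sum("d", "a"))  # 11
print(path_sum("d", "b"))  # 4
print(path_sum("d", "c"))  # 4
```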

Our autodiff implementation will go down the graph and compute the derivative of $d$ with respect to every sub-node, rather than just computing it for a particular node, as we have just done with $a$. Notice that we could compute the gradient of $d$ with respect to $c$, and $b$, without much more work.

Local derivatives

We saw ‘local derivatives’ on the graph edges above, written in the form: $ \frac{\partial \bar{y} }{\partial x}$.

The bar is to convey a simpler kind of differentiation.

In general: to get a local derivative, treat the variables going into a node as not being functions of other variables.

Note the distinction between local derivatives and partial derivatives: a partial derivative could take the whole graph into account, whereas a local derivative only looks at direct child nodes.

For example, recall that $d = a * c$. Then compare $ \frac{\partial{d}}{\partial{a}} = 2a + b $, to $ \frac{\partial \bar{d} }{\partial a} = c $. The local derivative $\frac{\partial \bar{d} }{\partial a} = c$, is obtained by treating $c$ as a constant before differentiating the expression for $d$.
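A rough numerical way to see the difference, using plain Python floats and central differences (the helper name `d_of` and the step size `h` are assumptions of this sketch):

```python
def d_of(a, c):
    return a * c

a, b, h = 4.0, 3.0, 1e-6
c = a + b

# Local derivative: nudge a, but hold c fixed at its current value.
local = (d_of(a + h, c) - d_of(a - h, c)) / (2 * h)

# Partial derivative: nudge a and let c = a + b change with it.
full = (d_of(a + h, (a + h) + b) - d_of(a - h, (a - h) + b)) / (2 * h)

print(round(local, 4))  # 7.0  (= c)
print(round(full, 4))   # 11.0 (= 2a + b)
```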

Local derivatives make our lives easier.

It is often easy to define the local derivatives of simple functions, and adding functions to the autodiff framework is easy if you know the local derivatives. E.g.

Addition: $n = a + b$. The local derivatives are: $\frac{\partial \bar{n} }{\partial a} = 1$ and $\frac{\partial \bar{n} }{\partial b} = 1$.

Multiplication: $n = a * b$. The local derivatives are: $\frac{\partial \bar{n} }{\partial a} = b$ and $\frac{\partial \bar{n} }{\partial b} = a$.

Let’s create the framework

Overview of the implementation

A variable (or node) contains two pieces of data:

  • value - the value of the variable.
  • local_gradients - the variable’s children and corresponding ‘local derivatives’.

The function get_gradients uses the variables’ local_gradients data to go through the graph recursively*, computing the gradients. (*I.e. local_gradients contains references to child variables, which have their own local_gradients, which contain references to child variables, which have their own local_gradients, which contain references to child variables, etc.)

The gradient of variable with respect to a child variable is computed using the rules we saw above:

  • For each path from variable to the child variable, multiply the edges of the path (giving path_value).
  • Sum the path values.

… This gives the first order partial derivative of variable, with respect to the child variable.

Implementing just enough to solve the example above:

from collections import defaultdict

class Variable:
    def __init__(self, value, local_gradients=()):
        self.value = value
        self.local_gradients = local_gradients
    
def add(a, b):
    "Create the variable that results from adding two variables."
    value = a.value + b.value    
    local_gradients = (
        (a, 1),  # the local derivative with respect to a is 1
        (b, 1)   # the local derivative with respect to b is 1
    )
    return Variable(value, local_gradients)

def mul(a, b):
    "Create the variable that results from multiplying two variables."
    value = a.value * b.value
    local_gradients = (
        (a, b.value), # the local derivative with respect to a is b.value
        (b, a.value)  # the local derivative with respect to b is a.value
    )
    return Variable(value, local_gradients)

def get_gradients(variable):
    """ Compute the first derivatives of `variable` 
    with respect to child variables.
    """
    gradients = defaultdict(lambda: 0)
    
    def compute_gradients(variable, path_value):
        for child_variable, local_gradient in variable.local_gradients:
            # "Multiply the edges of a path":
            value_of_path_to_child = path_value * local_gradient
            # "Add together the different paths":
            gradients[child_variable] += value_of_path_to_child
            # recurse through graph:
            compute_gradients(child_variable, value_of_path_to_child)
    
    compute_gradients(variable, path_value=1)
    # (path_value=1 is from `variable` differentiated w.r.t. itself)
    return gradients

Solving the example above:

a = Variable(4)
b = Variable(3)
c = add(a, b) # = 4 + 3 = 7
d = mul(a, c) # = 4 * 7 = 28

gradients = get_gradients(d)

print('d.value =', d.value)
print("The partial derivative of d with respect to a =", gradients[a])
d.value = 28
The partial derivative of d with respect to a = 11

Success!

Note we also get gradients for the other nodes:

print('gradients[b] =', gradients[b])
print('gradients[c] =', gradients[c])
gradients[b] = 4
gradients[c] = 4

Let’s take a look at the local_gradients values (the local derivatives):

print('dict(d.local_gradients)[a] =', dict(d.local_gradients)[a])
print('dict(d.local_gradients)[c] =', dict(d.local_gradients)[c])
print('dict(c.local_gradients)[a] =', dict(c.local_gradients)[a])
print('dict(c.local_gradients)[b] =', dict(c.local_gradients)[b])
dict(d.local_gradients)[a] = 7
dict(d.local_gradients)[c] = 4
dict(c.local_gradients)[a] = 1
dict(c.local_gradients)[b] = 1

We saw these in our example above as:

$ \frac{\partial \bar{d} }{\partial a} = c = 7$

$ \frac{\partial \bar{d} }{\partial c} = a = 4$

$ \frac{\partial \bar{c} }{\partial a} = 1$

$ \frac{\partial \bar{c} }{\partial b} = 1$


A few improvements

Let’s:

  • Add a couple more functions.
  • Enable the use of operators, such as +, *, - and /, using Python’s dunder (double underscore) methods.

class Variable:
    def __init__(self, value, local_gradients=()):
        self.value = value
        self.local_gradients = local_gradients
    
    def __add__(self, other):
        return add(self, other)
    
    def __mul__(self, other):
        return mul(self, other)
    
    def __sub__(self, other):
        return add(self, neg(other))

    def __truediv__(self, other):
        return mul(self, inv(other))
    
def add(a, b):
    value = a.value + b.value    
    local_gradients = (
        (a, 1),
        (b, 1)
    )
    return Variable(value, local_gradients)

def mul(a, b):
    value = a.value * b.value
    local_gradients = (
        (a, b.value),
        (b, a.value)
    )
    return Variable(value, local_gradients)

def neg(a):
    value = -1 * a.value
    local_gradients = (
        (a, -1),
    )
    return Variable(value, local_gradients)

def inv(a):
    value = 1. / a.value
    local_gradients = (
        (a, -1 / a.value**2),
    )
    return Variable(value, local_gradients)     
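As a quick check of the new operators, here is a self-contained sketch that repeats a trimmed copy of the definitions above (only `-` and `/` are wired up, to keep it short) and differentiates a small expression. The test expression and its values are chosen here purely for illustration.

```python
from collections import defaultdict

class Variable:
    def __init__(self, value, local_gradients=()):
        self.value = value
        self.local_gradients = local_gradients

    def __sub__(self, other):
        return add(self, neg(other))      # a - b  ==  a + (-b)

    def __truediv__(self, other):
        return mul(self, inv(other))      # a / b  ==  a * (1 / b)

def add(a, b):
    return Variable(a.value + b.value, ((a, 1), (b, 1)))

def mul(a, b):
    return Variable(a.value * b.value, ((a, b.value), (b, a.value)))

def neg(a):
    return Variable(-1 * a.value, ((a, -1),))

def inv(a):
    return Variable(1. / a.value, ((a, -1 / a.value**2),))

def get_gradients(variable):
    gradients = defaultdict(lambda: 0)
    def compute_gradients(variable, path_value):
        for child_variable, local_gradient in variable.local_gradients:
            value_of_path_to_child = path_value * local_gradient
            gradients[child_variable] += value_of_path_to_child
            compute_gradients(child_variable, value_of_path_to_child)
    compute_gradients(variable, path_value=1)
    return gradients

a = Variable(6.0)
b = Variable(2.0)
y = (a - b) / b   # built from add, neg, mul and inv behind the scenes

gradients = get_gradients(y)
print(y.value)        # 2.0
print(gradients[a])   # 0.5   (= 1 / b)
print(gradients[b])   # -1.5  (= -a / b**2)
```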

Some more examples

We can get the gradients of arbitrary functions made from the functions we’ve added to the framework. E.g.

def f(a, b):
    return (a / b - a) * (b / a + a + b) * (a - b)

a = Variable(230.3)
b = Variable(33.2)
y = f(a, b)

gradients = get_gradients(y)

print("The partial derivative of y with respect to a =", gradients[a])
print("The partial derivative of y with respect to b =", gradients[b])
The partial derivative of y with respect to a = -153284.83150602411
The partial derivative of y with respect to b = 3815.0389441500993

We can use numerical estimates to check that we’re getting correct results:

delta = Variable(1e-8)
numerical_grad_a = (f(a + delta, b) - f(a, b)) / delta
numerical_grad_b = (f(a, b + delta) - f(a, b)) / delta
print("The numerical estimate for a =", numerical_grad_a.value)
print("The numerical estimate for b =", numerical_grad_b.value)
The numerical estimate for a = -153284.89243984222
The numerical estimate for b = 3815.069794654846
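The forward-difference estimates above agree with autodiff to only about six significant figures, which is expected: with a step as small as 1e-8, floating-point round-off dominates. A central difference with a larger step usually gets closer, since its truncation error shrinks with the square of the step. A minimal sketch on plain floats (the helper `central_difference` and the step `h=1e-4` are assumptions of this example, not part of the framework):

```python
def f(a, b):
    return (a / b - a) * (b / a + a + b) * (a - b)

def central_difference(func, args, index, h=1e-4):
    """Estimate d func / d args[index] with a symmetric (central) difference."""
    up = list(args)
    up[index] += h
    down = list(args)
    down[index] -= h
    return (func(*up) - func(*down)) / (2 * h)

a, b = 230.3, 33.2
grad_a = central_difference(f, (a, b), 0)
grad_b = central_difference(f, (a, b), 1)
print("Central-difference estimate for a =", grad_a)
print("Central-difference estimate for b =", grad_b)
```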

It’s easy to add more functions

You just need to be able to define the local derivatives.

import numpy as np

def sin(a):
    value = np.sin(a.value)
    local_gradients = (
        (a, np.cos(a.value)),
    )
    return Variable(value, local_gradients)

def exp(a):
    value = np.exp(a.value)
    local_gradients = (
        (a, value),
    )
    return Variable(value, local_gradients)
    
def log(a):
    value = np.log(a.value)
    local_gradients = (
        (a, 1. / a.value),
    )
    return Variable(value, local_gradients)

Let’s check that these work:

a = Variable(43)
b = Variable(3)
c = Variable(2)

def f(a, b, c):
    g = sin(a * b) + exp(c - (a / b))
    return log(g * g) * c

y = f(a, b, c)

gradients = get_gradients(y)

print("The partial derivative of y with respect to a =", gradients[a])
print("The partial derivative of y with respect to b =", gradients[b])
print("The partial derivative of y with respect to c =", gradients[c])
The partial derivative of y with respect to a = 60.85353612046653
The partial derivative of y with respect to b = 872.2331479536114
The partial derivative of y with respect to c = -3.2853671032530305

delta = Variable(1e-8)
numerical_grad_a = (f(a + delta, b, c) - f(a, b, c)) / delta
numerical_grad_b = (f(a, b + delta, c) - f(a, b, c)) / delta
numerical_grad_c = (f(a, b, c + delta) - f(a, b, c)) / delta

print("The numerical estimate for a =", numerical_grad_a.value)
print("The numerical estimate for b =", numerical_grad_b.value)
print("The numerical estimate for c =", numerical_grad_c.value)
The numerical estimate for a = 60.85352186602222
The numerical estimate for b = 872.232160009645
The numerical estimate for c = -3.285367089489455
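The same recipe works for any function whose local derivative you can write down. As a further illustration, here is a hypothetical `tanh` (not one of the functions defined in this post), shown with a trimmed copy of the scalar framework so the snippet runs on its own:

```python
import math
from collections import defaultdict

class Variable:
    def __init__(self, value, local_gradients=()):
        self.value = value
        self.local_gradients = local_gradients

def mul(a, b):
    return Variable(a.value * b.value, ((a, b.value), (b, a.value)))

def tanh(a):
    "Hypothetical extra function: tanh, with local derivative 1 - tanh(a)**2."
    value = math.tanh(a.value)
    return Variable(value, ((a, 1 - value ** 2),))

def get_gradients(variable):
    gradients = defaultdict(lambda: 0)
    def compute_gradients(variable, path_value):
        for child_variable, local_gradient in variable.local_gradients:
            value_of_path_to_child = path_value * local_gradient
            gradients[child_variable] += value_of_path_to_child
            compute_gradients(child_variable, value_of_path_to_child)
    compute_gradients(variable, path_value=1)
    return gradients

a, b = Variable(0.5), Variable(0.3)
y = tanh(mul(a, b))
gradients = get_gradients(y)

# Compare with the closed form d/da tanh(a*b) = b * (1 - tanh(a*b)**2).
expected = b.value * (1 - math.tanh(a.value * b.value) ** 2)
print(gradients[a], expected)
```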

That’s the end of our minimal autodiff implementation!

Of course, there are various features missing, such as:

  • Vectorisation.
  • Nth derivatives.
  • Placeholder variables.
  • Optimisations.
  • The other great things deep learning / autodiff frameworks can do.

The following sections look at vectorisation and Nth derivatives.


A naive vectorisation

Let’s look at a computationally inefficient, but easy to implement, method of vectorising our autodiff framework.

The approach is:

  • Put our Variable objects from above into NumPy arrays
  • We can then use NumPy operations
  • That’s it.

import numpy as np

# convert NumPy array into array of Variable objects:
to_var = np.vectorize(lambda x : Variable(x))

# get values from array of Variable objects:
to_vals = np.vectorize(lambda variable : variable.value)
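To see the two helpers in action before the full training loop, here is a small self-contained sketch. It repeats a trimmed Variable (only + and *) so it runs on its own, and passes `otypes` to np.vectorize explicitly, which is a choice made for this example rather than something the helpers above do.

```python
import numpy as np
from collections import defaultdict

class Variable:
    def __init__(self, value, local_gradients=()):
        self.value = value
        self.local_gradients = local_gradients

    def __add__(self, other):
        return Variable(self.value + other.value, ((self, 1), (other, 1)))

    def __mul__(self, other):
        return Variable(self.value * other.value,
                        ((self, other.value), (other, self.value)))

def get_gradients(variable):
    gradients = defaultdict(lambda: 0)
    def compute_gradients(variable, path_value):
        for child_variable, local_gradient in variable.local_gradients:
            value_of_path_to_child = path_value * local_gradient
            gradients[child_variable] += value_of_path_to_child
            compute_gradients(child_variable, value_of_path_to_child)
    compute_gradients(variable, path_value=1)
    return gradients

to_var = np.vectorize(lambda x: Variable(x), otypes=[object])
to_vals = np.vectorize(lambda variable: variable.value, otypes=[float])

x = to_var(np.array([1.0, 2.0, 3.0]))
weights = to_var(np.ones((3, 2)))

y = np.dot(x, weights)   # object array holding 2 Variables
total = np.sum(y)        # a single Variable

gradients = get_gradients(total)
print(to_vals(y))                  # [6. 6.]
print(gradients[weights[1, 0]])    # 2.0  (= x[1])
```

NumPy falls back to calling `__add__` and `__mul__` on each element of an object array, which is why this works and also why it is slow.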

A single linear layer neural network (fitting noise to noise):

import matplotlib.pyplot as plt
np.random.seed(0)

def update_weights(weights, gradients, lrate):
    for _, weight in np.ndenumerate(weights):
        weight.value -= lrate * gradients[weight]

input_size = 50
output_size = 10
lrate = 0.001

x = to_var(np.random.random(input_size))
y_true = to_var(np.random.random(output_size))
weights = to_var(np.random.random((input_size, output_size)))

loss_vals = []
for i in range(100):
    y_pred = np.dot(x, weights)
    loss = np.sum((y_true - y_pred) * (y_true - y_pred))
    loss_vals.append(loss.value)
    gradients = get_gradients(loss)
    update_weights(weights, gradients, lrate)

plt.plot(loss_vals)
plt.xlabel("Time step")
plt.ylabel("Loss")
plt.title("Single linear layer learning")
plt.show()

Plot of the loss of the linear layer.


Nth order derivatives

Note that enabling Nth order derivatives using reverse-mode autodiff is more costly computationally than only enabling first order derivatives. (In practice, deep learning models are mostly trained using first order derivatives only. Widely used optimisation methods, such as Adam, work from running statistics of the first order gradients rather than computing higher order derivatives.)

We can enable our framework to compute Nth order derivatives by changing our get_gradients computations to use our Variable objects. This means that autodiff graphs will be created when computing derivatives, and we can then compute the derivatives of the derivatives that we are interested in with respect to the variables that we are interested in (and then the derivatives of derivatives, and derivatives of derivatives of derivatives, etc.).

The changes we will make are:

  1. To compute gradients using Variable objects.
  2. To change local_gradients to contain functions, instead of local gradient values.

A local_gradient function, multiply_by_locgrad, will multiply the path value that is passed to it with the local gradient value. This previously happened within the scope of get_gradients. However, since we are now defining local_gradients using our own Variable operations, we need to prevent infinite recursion (e.g. the local gradients of div are calculated using div, which would then compute more local gradients using div, which would then compute more local gradients using div, etc.). Using functions prevents this, because a local gradient is only computed when its function is called.

from collections import defaultdict

class Variable:
    def __init__(self, value, local_gradients=()):
        self.value = value
        self.local_gradients = local_gradients

    def __add__(a, b):
        return add(a, b)
    
    def __mul__(a, b):
        return mul(a, b)
    
    def __truediv__(a, b):
        return div(a, b)

# Useful constants:
ONE = Variable(1.)
NEG_ONE = Variable(-1.)

def add(a, b):
    value = a.value + b.value    
    local_gradients = (
        # Note that local_gradients now contains lambda functions.
        (a, lambda path_value: path_value),
        # local gradient is 1, so multiply path_value by 1.
        (b, lambda path_value: path_value)
        # local gradient is 1, so multiply path_value by 1.
    )
    return Variable(value, local_gradients)

def mul(a, b):
    value = a.value * b.value
    local_gradients = (
        (a, lambda path_value: path_value * b),
        # local gradient for a is b, so multiply path_value by b.
        (b, lambda path_value : path_value * a)
        # local gradient for b is a, so multiply path_value by a.
    )
    return Variable(value, local_gradients)

def div(a, b):
    value = a.value / b.value
    local_gradients = (
        (a, lambda path_value : path_value * ONE/b),
        (b, lambda path_value : path_value * NEG_ONE * a/(b*b))
    )
    return Variable(value, local_gradients)

def get_gradients(variable):
    """ Compute the first derivatives of `variable` 
    with respect to child variables.
    """
    gradients = defaultdict(lambda: Variable(0))
    
    def compute_gradients(variable, path_value):
        for child_variable, multiply_by_locgrad in variable.local_gradients:
            # "Multiply the edges of a path":
            value_of_path_to_child = multiply_by_locgrad(path_value)  # Now a function is used here.
            # "Add together the different paths":
            gradients[child_variable] += value_of_path_to_child
            # recurse through graph:
            compute_gradients(child_variable, value_of_path_to_child)
    
    compute_gradients(variable, path_value=ONE)  # Now path_value is a Variable.
    # (path_value=1 is from `variable` differentiated w.r.t. itself)
    return gradients

A bunch of examples:

"""
A 2nd derivative example.
y = x*x = x**2
y' = 2x (= 2*3 = 6)
y'' = 2
"""

x = Variable(3)
y = x * x

derivs_1 = get_gradients(y)
dy_by_dx = derivs_1[x]

print('y.value =', y.value)
print("The derivative of y with respect to x =", dy_by_dx.value)

derivs_2 = get_gradients(dy_by_dx)
dy_by_dx2 = derivs_2[x]
print("The 2nd derivative of y with respect to x =", dy_by_dx2.value)
y.value = 9
The derivative of y with respect to x = 6.0
The 2nd derivative of y with respect to x = 2.0
"""
2nd derivative of our original example.
d = a*(a + b)
d = a**2 + ab
dd_by_da = 2a + b (= 2*4 + 3 = 11)
dd_by_da2 = 2
"""

a = Variable(4)
b = Variable(3)
c = a + b # = 4 + 3 = 7
d = a * c # = 4 * 7 = 28

derivs_1 = get_gradients(d)
dd_by_da = derivs_1[a]

print('d.value =', d.value)
print("The partial derivative of d with respect to a =", dd_by_da.value)

derivs_2 = get_gradients(dd_by_da)
dd_by_da2 = derivs_2[a]
print("The 2nd partial derivative of d with respect to a =", dd_by_da2.value)
d.value = 28
The partial derivative of d with respect to a = 11.0
The 2nd partial derivative of d with respect to a = 2.0
"""
Another second derivative example, with division.
y = a*a/b
dy_by_da = 2*a/b (=2*3/7=0.8571...)
dy_by_da2 = 2/b (=2/7=0.2857...)
"""

a = Variable(3)
b = Variable(7)
y = a * a / b

derivs_1 = get_gradients(y)
dy_by_da = derivs_1[a]

print('y.value =', y.value)
print("The partial derivative of y with respect to a =", dy_by_da.value)

derivs_2 = get_gradients(dy_by_da)
dy_by_da2 = derivs_2[a]
print("The 2nd partial derivative of y with respect to a =", dy_by_da2.value)
y.value = 1.2857142857142858
The partial derivative of y with respect to a = 0.8571428571428571
The 2nd partial derivative of y with respect to a = 0.2857142857142857
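The second call to get_gradients does not have to use the same variable as the first. As a sketch, here is a mixed second partial derivative of the same y = a*a/b, differentiating first by a and then by b, which should give -2a/b². The block repeats a condensed copy of the framework above so that it runs on its own.

```python
from collections import defaultdict

class Variable:
    def __init__(self, value, local_gradients=()):
        self.value = value
        self.local_gradients = local_gradients

    def __add__(a, b):
        return add(a, b)

    def __mul__(a, b):
        return mul(a, b)

    def __truediv__(a, b):
        return div(a, b)

ONE = Variable(1.)
NEG_ONE = Variable(-1.)

def add(a, b):
    return Variable(a.value + b.value,
                    ((a, lambda p: p), (b, lambda p: p)))

def mul(a, b):
    return Variable(a.value * b.value,
                    ((a, lambda p: p * b), (b, lambda p: p * a)))

def div(a, b):
    return Variable(a.value / b.value,
                    ((a, lambda p: p * ONE / b),
                     (b, lambda p: p * NEG_ONE * a / (b * b))))

def get_gradients(variable):
    gradients = defaultdict(lambda: Variable(0))
    def compute_gradients(variable, path_value):
        for child_variable, multiply_by_locgrad in variable.local_gradients:
            value_of_path_to_child = multiply_by_locgrad(path_value)
            gradients[child_variable] += value_of_path_to_child
            compute_gradients(child_variable, value_of_path_to_child)
    compute_gradients(variable, path_value=ONE)
    return gradients

a = Variable(3)
b = Variable(7)
y = a * a / b

dy_by_da = get_gradients(y)[a]              # 2*a/b
d2y_by_da_db = get_gradients(dy_by_da)[b]   # expected: -2*a/b**2

print("Mixed partial from autodiff =", d2y_by_da_db.value)
print("Closed form -2*a/b**2       =", -2 * 3 / 7 ** 2)
```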
"""
Nth derivative example
y = x**N
y' = N*x**(N-1) (=10*3**9=196830)
y'' = (N-1)*N*x**(N-2) (=9*10*3**8=590490)
y''' = (N-2)(N-1)*N*x**(N-3) (=8*9*10*3**7=1574640)
etc.
"""
from magic_timer import MagicTimer

N = 10
x = Variable(3)

y = ONE
for i in range(N):
    y *= x

print(f'At x={x.value}:\n')
    
target = y
for n in range(1,N+1):
    timer = MagicTimer()
    
    derivs_n = get_gradients(target)
    dy_by_dx_nth = derivs_n[x]
    target = dy_by_dx_nth
    
    print(f'The n={n} derivative of y w.r.t x is', dy_by_dx_nth.value)
    print('Elapsed time:', timer)
    print()
At x=3:

The n=1 derivative of y w.r.t x is 196830.0
Elapsed time: 693 microseconds

The n=2 derivative of y w.r.t x is 590490.0
Elapsed time: 6.3 milliseconds

The n=3 derivative of y w.r.t x is 1574640.0
Elapsed time: 13 milliseconds

The n=4 derivative of y w.r.t x is 3674160.0
Elapsed time: 360 milliseconds

---------------------------------------------------------------------------

RecursionError                            Traceback (most recent call last)
...
RecursionError: maximum recursion depth exceeded
"""
If we wanted very high order derivatives,
we could increase Python's recursion limit.
Alternatively, we could rewrite get_gradients,
so that the function doesn't recurse.
"""
import sys
print('Original recursion limit:', sys.getrecursionlimit())
sys.setrecursionlimit(10_000)
print('New recursion limit:', sys.getrecursionlimit())
Original recursion limit: 3000
New recursion limit: 10000
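For the second option, one possible shape of a non-recursive get_gradients is sketched below. It keeps an explicit stack of (node, path value) pairs instead of relying on Python's call stack. To stay short it uses the simpler first-order framework from earlier (numeric local gradients), and the name `get_gradients_iterative` is made up for this example; the same loop structure should carry over to the function-based version.

```python
from collections import defaultdict

class Variable:
    def __init__(self, value, local_gradients=()):
        self.value = value
        self.local_gradients = local_gradients

def add(a, b):
    return Variable(a.value + b.value, ((a, 1), (b, 1)))

def mul(a, b):
    return Variable(a.value * b.value, ((a, b.value), (b, a.value)))

def get_gradients_iterative(variable):
    """Same path-summing rules as get_gradients, but with an explicit stack."""
    gradients = defaultdict(lambda: 0)
    stack = [(variable, 1)]  # (node, value of the path leading to it)
    while stack:
        node, path_value = stack.pop()
        for child_variable, local_gradient in node.local_gradients:
            value_of_path_to_child = path_value * local_gradient
            gradients[child_variable] += value_of_path_to_child
            stack.append((child_variable, value_of_path_to_child))
    return gradients

a, b = Variable(4), Variable(3)
d = mul(a, add(a, b))
gradients = get_gradients_iterative(d)
print(gradients[a], gradients[b])  # 11 4

# A chain far deeper than Python's default recursion limit:
x = Variable(2.0)
y = x
for _ in range(5000):
    y = add(y, Variable(1.0))
print(get_gradients_iterative(y)[x])  # 1
```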

Coming soon: part two, where we look at how to vectorize our minimal framework more efficiently.

Check out this repo for a taste of what’s to come: https://github.com/sradc/SmallPebble.

