Automatic differentiation (AD) refers to a family of algorithms that can be used to compute derivatives of functions in a systematized way. AD has been popularized by a slew of machine learning libraries, such as autograd, TensorFlow, and PyTorch. However, the details of how these tools actually compute derivatives are often overlooked. This post attempts to explain some of those details and show simple implementations of the two main types of AD: forward-mode and reverse-mode.
In my opinion, the term “automatic differentiation” is a bit of a misnomer. Automatic implies that it’s a general-purpose rule for computing derivatives of arbitrary functions without human specification. In reality, the derivatives of low-level base functions must still be encoded somehow. A better name might be “systematized differentiation” or “programmatic differentiation”. In my mind, related techniques such as numerical differentiation are closer to being automatic, although these don’t compute exact derivatives. Nevertheless, the name has stuck.
We first review the basic calculus used by AD, and then we present the two modes of AD.
The basic ideas of AD rely heavily (almost exclusively) on the chain rule of calculus. As a brief refresher, suppose we have a composition of two functions $y = f(g(x))$. Recall that the chain rule gives us a method for computing the derivative with respect to $x$:
\[\frac{\partial y}{\partial x} = \frac{\partial f}{\partial g} \cdot \frac{\partial g}{\partial x}.\]We can generalize this to arbitrary compositions of functions,
\[f_1 \circ f_2 \circ \dots \circ f_k,\]in which case the chain rule yields
\[\frac{\partial y}{\partial x} = \frac{\partial f_1}{\partial f_2} \cdot \frac{\partial f_2}{\partial f_3} \dots \frac{\partial f_{k}}{\partial x}.\]At its core, one can think of AD as a clever software implementation of the chain rule. By decomposing complex derivatives into simpler partial derivatives, we can write extensible software that computes derivatives of essentially arbitrary code.
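As a quick concrete instance of this rule, take $y = \cos(x^2)$ with $f(u) = \cos u$ and $g(x) = x^2$:
\[\frac{\partial y}{\partial x} = \frac{\partial f}{\partial g} \cdot \frac{\partial g}{\partial x} = -\sin(x^2) \cdot 2x.\]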
For this post, let’s consider the following simple function $f : \mathbb{R}^2 \rightarrow \mathbb{R}$ as a running example:
\begin{equation} y = f(x_1, x_2) = e^{2 x_1} + x_1 x_2^2 + \cos x_2. \tag{1} \label{eq:running_example} \end{equation}
We’ll use this function as a testbed throughout this post.
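For reference, differentiating Equation \eqref{eq:running_example} by hand gives
\[\frac{\partial y}{\partial x_1} = 2 e^{2 x_1} + x_2^2, \qquad \frac{\partial y}{\partial x_2} = 2 x_1 x_2 - \sin x_2,\]
which is what any correct AD implementation of this function should reproduce; we'll use these expressions to check our code later on.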
Another key step in algorithm-itizing the process of differentiation is defining a representation of functions using a data structure that's suited to the task. Given that the chain rule is based on the idea of composing multiple functions together, a natural structure to consider is a directed graph. In particular, if we can decompose a function into a set of simpler elementary computations, we can perform each of these computations within a separate node on the graph, and feed outputs to inputs along the edges.
To see this, consider our running example in Equation \eqref{eq:running_example}. Notice that we can break this function down into a set of elementary operations. First, we multiply $x_1$ by $2$, then we exponentiate the result, then we square $x_2$, then we multiply that by $x_1$, and so on. If we denote each intermediate value by $w_1, w_2, \dots$, we can decompose the function as follows. Note that we treat $w_1$ and $w_2$ as the input data itself, $x_1$ and $x_2$. Below, we list each step and the operation that occurs.
\begin{align} y &= e^{2 x_1} + x_1 x_2^2 + \cos x_2 \\ &= e^{2 w_1} + w_1 w_2^2 + \cos w_2 & \text{(Encode data: $w_1 = x_1$, $w_2 = x_2$)} \\ &= e^{w_3} + w_1 w_2^2 + \cos w_2 & \text{(Multiply: $w_3 = 2 w_1$)} \\ &= w_4 + w_1 w_2^2 + \cos w_2 & \text{(Exponentiate: $w_4 = e^{w_3}$)} \\ &= w_4 + w_1 w_5 + \cos w_2 & \text{(Square: $w_5 = w_2^2$)} \\ &= w_4 + w_6 + \cos w_2 & \text{(Multiply: $w_6 = w_1 w_5$)} \\ &= w_4 + w_6 + w_7 & \text{(Cosine: $w_7 = \cos w_2$)} \\ &= w_8 + w_7 & \text{(Add: $w_8 = w_4 + w_6$)} \\ &= w_9. & \text{(Add: $w_9 = w_7 + w_8$)} \end{align}
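Viewed as code, evaluating the function amounts to running these elementary operations in order. Here's a minimal sketch in plain Python (the function name f_trace is just for illustration):

import numpy as np

def f_trace(x1, x2):
    w1, w2 = x1, x2     # encode data
    w3 = 2 * w1         # multiply by 2
    w4 = np.exp(w3)     # exponentiate
    w5 = w2**2          # square
    w6 = w1 * w5        # multiply
    w7 = np.cos(w2)     # cosine
    w8 = w4 + w6        # add
    w9 = w7 + w8        # add
    return w9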
In graph form, we can think of each $w_i$ as a vertex representing an elementary operation. This computation graph might look like the following:
Both forward- and reverse-mode AD make use of this computation graph to compute derivatives. We now have the basic ingredients needed for AD.
In forward-mode AD, we start at the “input” end of the graph, and work our way forward through the graph (left to right in the above graph). At each intermediate value, we compute a derivative along the way. In terms of a composition of functions, we start with the “innermost” function, and work outward.
In terms of a data structure, we need to store a derivative within each vertex in the graph. For node $w_i$, we use the notation $\dot{w_i}$ to denote its derivative with respect to the input, $\dot{w_i} = \partial w_i / \partial x$. The “dot” notation for derivatives is common in the physics community.
The basic algorithm for forward-mode AD is as follows. Suppose we'd like to take the derivative with respect to the first element, $x_1$, of a two-element input. We seed the input nodes with $\dot{w_1} = 1$ and $\dot{w_2} = 0$, then visit the remaining nodes in the order they appear in the computation, computing each node's value and its derivative $\dot{w_i}$ from those of its children using the chain rule. The derivative stored at the output node is then $\partial y / \partial x_1$.
Note that as we traverse forward in the graph, we will need the derivatives of the children of each node so that we can properly use the chain rule on the fly.
To see this in practice, let’s come back to our running example. To compute the derivative with respect to $x_1$, we set $\dot{w_1} = 1, \dot{w_2} = 0$. We can then traverse forward in the graph as follows:
\begin{align} \dot{w_1} &= 1 \\ \dot{w_2} &= 0 \\ \dot{w_3} &= \frac{\partial w_3}{\partial w_1} \cdot \dot{w_1} = 2 \dot{w_1} \\ \dot{w_4} &= \frac{\partial w_4}{\partial w_3} \cdot \dot{w_3} = e^{w_3} \dot{w_3} \\ \dot{w_5} &= \frac{\partial w_5}{\partial w_2} \cdot \dot{w_2} = 2 w_2 \dot{w_2} \\ \dot{w_6} &= \dot{w_1} w_5 + w_1 \dot{w_5} \\ \dot{w_7} &= \frac{\partial w_7}{\partial w_2} \cdot \dot{w_2} = -\sin(w_2) \, \dot{w_2} \\ \dot{w_8} &= \dot{w_4} + \dot{w_6} \\ \dot{w_9} &= \dot{w_7} + \dot{w_8}. \end{align}
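For concreteness, at the point $(x_1, x_2) = (1, 2)$ these updates evaluate to $\dot{w_3} = 2$, $\dot{w_4} = 2e^2$, $\dot{w_5} = 0$, $\dot{w_6} = 4$, $\dot{w_7} = 0$, $\dot{w_8} = 2e^2 + 4$, and finally $\dot{w_9} = 2e^2 + 4 \approx 18.78$, which matches the hand-derived $\partial y / \partial x_1 = 2 e^{2 x_1} + x_2^2$ at that point.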
Above, we only took the derivative with respect to $x_1$. To compute the derivative with respect to $x_2$, we would repeat the same process, but this time we would set the seeds as $\dot{w_1} = 0, \dot{w_2} = 1$. Notice that the computational complexity of forward-mode AD thus scales with the size of the input.
To implement forward-mode AD, we need a simple graph data structure. Here, I’ll show a simple (albeit not very extensible) way to do this. First, consider a Vertex class for a computation graph.
class Vertex():
    def __init__(self, val=None, val_dot=None, children=[]):
        self.val = val
        self.val_dot = val_dot
        self.children = children

    def forward(self):
        # Perform elementary operation
        pass

    def differentiate(self):
        # Compute "dot" of this node
        pass
Note that we need both a forward function and a differentiate function. The edges in the graph are determined by the children of each node. A Graph class will simply be made up of a collection of Vertex objects.
class Graph():
    def __init__(self):
        self.input = None

    def forward(self, x):
        pass
The forward function in a Graph will call forward iteratively on each Vertex.
We can then implement a class for each elementary operation that we need, and these classes will inherit from the generic Vertex class. For example, for multiplication by a scalar ($2$ in this case), we can consider the following class.
class TimesTwo(Vertex):
    def __init__(self, val=None, val_dot=None, children=[]):
        super().__init__()
        self.val = val
        self.val_dot = val_dot
        self.children = children

    def forward(self):
        self.val = 2 * self.children[0].val

    def differentiate(self):
        self.val_dot = 2 * self.children[0].val_dot
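To get a feel for how a single vertex behaves, here's a small illustrative snippet; it relies on the Data class from the Appendix, and the values in the comments are simply the results we expect:

a = Data(val=3.0, val_dot=1.0)  # input node, seeded with da/da = 1
b = TimesTwo(children=[a])
b.forward()                     # b.val is now 6.0
b.differentiate()               # b.val_dot is now 2.0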
See the Appendix for the full code of each elementary operation in our running example. By chaining these elementary operations together, we can create an entire graph. Implementing the forward and differentiate functions for the entire graph is then just a matter of looping through the vertices and calling the appropriate function.
class SimpleFunction(Graph):
    def __init__(self):
        super().__init__()
        self.input = [None, None]
        W1Vertex = Data(val=self.input[0])
        W2Vertex = Data(val=self.input[1])
        W3Vertex = TimesTwo(children=[W1Vertex])
        W4Vertex = Exp(children=[W3Vertex])
        W5Vertex = Square(children=[W2Vertex])
        W6Vertex = Multiply(children=[W1Vertex, W5Vertex])
        W7Vertex = Cos(children=[W2Vertex])
        W8Vertex = Add(children=[W4Vertex, W6Vertex])
        W9Vertex = Add(children=[W7Vertex, W8Vertex])
        self.vertex_list = [
            W1Vertex,
            W2Vertex,
            W3Vertex,
            W4Vertex,
            W5Vertex,
            W6Vertex,
            W7Vertex,
            W8Vertex,
            W9Vertex
        ]

    def insert_data_into_graph(self, data):
        self.vertex_list[0].val = data[0]
        self.vertex_list[1].val = data[1]

    def forward(self, x):
        self.insert_data_into_graph(x)
        for v in self.vertex_list:
            v.forward()
        return self.vertex_list[-1].val

    def differentiate(self, x):
        self.forward(x)
        # Seed: derivative with respect to x1
        self.vertex_list[0].val_dot = 1
        self.vertex_list[1].val_dot = 0
        for v in self.vertex_list[2:]:
            v.differentiate()
        return self.vertex_list[-1].val_dot
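With the graph wired up, computing the function value and $\partial y / \partial x_1$ at a point looks like the following (the values in the comments are the expected results):

f = SimpleFunction()
print(f.forward([1.0, 2.0]))        # e^2 + 4 + cos(2), roughly 10.97
print(f.differentiate([1.0, 2.0]))  # 2*e^2 + 4, roughly 18.78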
In reverse-mode AD, we traverse the computation graph in the other direction: from output to input (from the outermost function to the innermost). In this case, rather than computing the derivative of each node with respect to the input, we compute the derivative of the output with respect to each node: $\partial y/\partial w_i$.
Algorithmically, we store a quantity known as the “adjoint” for each node, where we typically use a bar notation,
\[\bar{w}_i = \frac{\partial y}{\partial w_i}.\]As we traverse backward through the graph, we can compute each individual adjoint using the previous adjoints computed for each node’s parents:
\[\bar{w}_i = \sum\limits_{p \in \text{parents}} \bar{w}_p \frac{\partial w_p}{\partial w_i}.\]In contrast to forward-mode AD, reverse-mode AD is able to compute the derivatives with respect to both $x_1$ and $x_2$ with just one pass. In general, the computational cost of reverse-mode AD will scale with the size of the output. In most settings, it's more likely that the input is high-dimensional, while the output is one- or low-dimensional, so reverse-mode AD is often computationally preferable. Reverse-mode AD serves as the basis for the backpropagation algorithm, a popular approach to training neural networks.
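Returning to the running example, with the intermediate variables numbered as in the code below, a single backward pass computes
\begin{align} \bar{w}_9 &= 1 \\ \bar{w}_8 &= \bar{w}_9 \cdot \frac{\partial w_9}{\partial w_8} = \bar{w}_9 \\ \bar{w}_7 &= \bar{w}_9 \cdot \frac{\partial w_9}{\partial w_7} = \bar{w}_9 \\ \bar{w}_6 &= \bar{w}_8 \cdot \frac{\partial w_8}{\partial w_6} = \bar{w}_8 \\ \bar{w}_5 &= \bar{w}_6 \cdot \frac{\partial w_6}{\partial w_5} = \bar{w}_6 \, w_1 \\ \bar{w}_4 &= \bar{w}_8 \cdot \frac{\partial w_8}{\partial w_4} = \bar{w}_8 \\ \bar{w}_3 &= \bar{w}_4 \cdot \frac{\partial w_4}{\partial w_3} = \bar{w}_4 \, e^{w_3} \\ \bar{w}_2 &= \bar{w}_5 \cdot 2 w_2 + \bar{w}_7 \cdot (-\sin w_2) \\ \bar{w}_1 &= \bar{w}_3 \cdot 2 + \bar{w}_6 \cdot w_5, \end{align}
so that $\bar{w}_1 = 2 e^{2 x_1} + x_2^2$ and $\bar{w}_2 = 2 x_1 x_2 - \sin x_2$: both partial derivatives fall out of a single backward traversal.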
In this case, the Python implementation looks very similar. However, instead of computing the derivative for each node, we compute and store the adjoint for each node’s children at each step. For example, for the multiply operation, we have the following class:
class Multiply(Vertex):
    def __init__(self, val=None, adjoint=0, children=[]):
        super().__init__()
        self.val = val
        self.adjoint = adjoint
        self.children = children

    def forward(self):
        self.val = self.children[0].val * self.children[1].val

    def differentiate(self):
        # Product rule: propagate this node's adjoint to each child
        self.children[0].adjoint += self.adjoint * self.children[1].val
        self.children[1].adjoint += self.adjoint * self.children[0].val
Note that we initialize the adjoint to be $0$ for each node, so that we can properly accumulate the contribution from each of its parents. The full class for the graph then looks like the following.
class SimpleFunction(Graph):
    def __init__(self):
        super().__init__()
        self.input = [None, None]
        W1Vertex = Data(val=self.input[0])
        W2Vertex = Data(val=self.input[1])
        W3Vertex = TimesTwo(children=[W1Vertex])
        W4Vertex = Exp(children=[W3Vertex])
        W5Vertex = Square(children=[W2Vertex])
        W6Vertex = Multiply(children=[W1Vertex, W5Vertex])
        W7Vertex = Cos(children=[W2Vertex])
        W8Vertex = Add(children=[W4Vertex, W6Vertex])
        W9Vertex = Add(children=[W7Vertex, W8Vertex])
        self.vertex_list = [
            W1Vertex,
            W2Vertex,
            W3Vertex,
            W4Vertex,
            W5Vertex,
            W6Vertex,
            W7Vertex,
            W8Vertex,
            W9Vertex
        ]

    def insert_data_into_graph(self, data):
        self.vertex_list[0].val = data[0]
        self.vertex_list[1].val = data[1]

    def forward(self, x):
        self.insert_data_into_graph(x)
        for v in self.vertex_list:
            v.forward()
        return self.vertex_list[-1].val

    def differentiate(self, x):
        self.forward(x)
        # Reset adjoints so that repeated calls don't accumulate stale values
        for v in self.vertex_list:
            v.adjoint = 0
        self.vertex_list[-1].adjoint = 1
        n_nodes = len(self.vertex_list)
        ii = n_nodes - 1
        while ii >= 0:
            self.vertex_list[ii].differentiate()
            ii -= 1
        return [self.vertex_list[0].adjoint, self.vertex_list[1].adjoint]
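As a quick check that a single call now returns both partial derivatives at once (expected values shown in the comment):

f = SimpleFunction()
print(f.differentiate([1.0, 2.0]))  # [2*e^2 + 4, 4 - sin(2)], roughly [18.78, 3.09]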
It's worth noting that the implementation I've shown above is purely for pedagogical purposes, and it's nowhere near as general or useful as the implementations in packages like autograd or PyTorch. While my code above manually forms a graph, these packages have much more sophisticated machinery for "tracing" through a function and automatically building a computation graph, along with much more efficient linear algebra routines. Furthermore, they have much more general classes for computing derivatives.
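For comparison, here's roughly what the same computation looks like in PyTorch, where the graph is traced automatically as the expression is evaluated (a small sketch, not a tour of the full API):

import torch

x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)

y = torch.exp(2 * x1) + x1 * x2**2 + torch.cos(x2)
y.backward()  # reverse-mode AD

print(x1.grad, x2.grad)  # roughly 18.78 and 3.09, matching the results above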
Full code for forward-mode AD:
import numpy as np


class Vertex():
    def __init__(self, val=None, val_dot=None, children=[]):
        self.val = val
        self.val_dot = val_dot
        self.children = children

    def forward(self):
        pass

    def differentiate(self):
        pass


class Graph():
    def __init__(self):
        self.input = None

    def forward(self, x):
        pass
class Data(Vertex):
    def __init__(self, val=None, val_dot=None, children=[]):
        super().__init__()
        self.val = val
        self.val_dot = val_dot
        self.children = children

    def forward(self):
        pass

    def differentiate(self):
        # The seed derivative val_dot is set externally; nothing to compute here
        pass
class TimesTwo(Vertex):
    def __init__(self, val=None, val_dot=None, children=[]):
        super().__init__()
        self.val = val
        self.val_dot = val_dot
        self.children = children

    def forward(self):
        self.val = 2 * self.children[0].val

    def differentiate(self):
        self.val_dot = 2 * self.children[0].val_dot


class Exp(Vertex):
    def __init__(self, val=None, val_dot=None, children=[]):
        super().__init__()
        self.val = val
        self.val_dot = val_dot
        self.children = children

    def forward(self):
        self.val = np.exp(self.children[0].val)

    def differentiate(self):
        self.val_dot = self.children[0].val_dot * np.exp(self.children[0].val)


class Square(Vertex):
    def __init__(self, val=None, val_dot=None, children=[]):
        super().__init__()
        self.val = val
        self.val_dot = val_dot
        self.children = children

    def forward(self):
        self.val = self.children[0].val**2

    def differentiate(self):
        self.val_dot = self.children[0].val_dot * 2 * self.children[0].val


class Multiply(Vertex):
    def __init__(self, val=None, val_dot=None, children=[]):
        super().__init__()
        self.val = val
        self.val_dot = val_dot
        self.children = children

    def forward(self):
        self.val = self.children[0].val * self.children[1].val

    def differentiate(self):
        self.val_dot = self.children[0].val_dot * self.children[1].val + self.children[0].val * self.children[1].val_dot


class Cos(Vertex):
    def __init__(self, val=None, val_dot=None, children=[]):
        super().__init__()
        self.val = val
        self.val_dot = val_dot
        self.children = children

    def forward(self):
        self.val = np.cos(self.children[0].val)

    def differentiate(self):
        self.val_dot = self.children[0].val_dot * -1 * np.sin(self.children[0].val)


class Add(Vertex):
    def __init__(self, val=None, val_dot=None, children=[]):
        super().__init__()
        self.val = val
        self.val_dot = val_dot
        self.children = children

    def forward(self):
        self.val = self.children[0].val + self.children[1].val

    def differentiate(self):
        self.val_dot = self.children[0].val_dot + self.children[1].val_dot
class SimpleFunction(Graph):
    def __init__(self):
        super().__init__()
        self.input = [None, None]
        W1Vertex = Data(val=self.input[0])
        W2Vertex = Data(val=self.input[1])
        W3Vertex = TimesTwo(children=[W1Vertex])
        W4Vertex = Exp(children=[W3Vertex])
        W5Vertex = Square(children=[W2Vertex])
        W6Vertex = Multiply(children=[W1Vertex, W5Vertex])
        W7Vertex = Cos(children=[W2Vertex])
        W8Vertex = Add(children=[W4Vertex, W6Vertex])
        W9Vertex = Add(children=[W7Vertex, W8Vertex])
        self.vertex_list = [
            W1Vertex,
            W2Vertex,
            W3Vertex,
            W4Vertex,
            W5Vertex,
            W6Vertex,
            W7Vertex,
            W8Vertex,
            W9Vertex
        ]

    def insert_data_into_graph(self, data):
        self.vertex_list[0].val = data[0]
        self.vertex_list[1].val = data[1]

    def forward(self, x):
        self.insert_data_into_graph(x)
        for v in self.vertex_list:
            v.forward()
        return self.vertex_list[-1].val

    def differentiate(self, x):
        self.forward(x)

        # Derivative wrt x1
        self.vertex_list[0].val_dot = 1
        self.vertex_list[1].val_dot = 0
        for v in self.vertex_list[2:]:
            v.differentiate()
        dx1 = self.vertex_list[-1].val_dot

        # Derivative wrt x2
        self.vertex_list[0].val_dot = 0
        self.vertex_list[1].val_dot = 1
        for v in self.vertex_list[2:]:
            v.differentiate()
        dx2 = self.vertex_list[-1].val_dot

        return [dx1, dx2]


if __name__ == "__main__":
    function_graph = SimpleFunction()
    x1, x2 = 1, 2

    output = function_graph.forward([x1, x2])
    numpy_truth = np.exp(2 * x1) + x1 * x2**2 + np.cos(x2)
    assert output == numpy_truth

    dx = function_graph.differentiate([x1, x2])
    dx_true = [2 * np.exp(2 * x1) + x2**2, 2 * x1 * x2 - np.sin(x2)]
    assert np.all(dx == dx_true)
Full code for reverse-mode AD:
import numpy as np


class Vertex():
    def __init__(self, val=None, adjoint=0, children=[]):
        self.val = val
        self.adjoint = adjoint
        self.children = children

    def forward(self):
        pass

    def differentiate(self):
        pass


class Graph():
    def __init__(self):
        self.input = None

    def forward(self, x):
        pass
class Data(Vertex):
    def __init__(self, val=None, adjoint=0, children=[]):
        super().__init__()
        self.val = val
        self.adjoint = adjoint
        self.children = children

    def forward(self):
        pass

    def differentiate(self):
        pass


class TimesTwo(Vertex):
    def __init__(self, val=None, adjoint=0, children=[]):
        super().__init__()
        self.val = val
        self.adjoint = adjoint
        self.children = children

    def forward(self):
        self.val = 2 * self.children[0].val

    def differentiate(self):
        self.children[0].adjoint += self.adjoint * 2


class Exp(Vertex):
    def __init__(self, val=None, adjoint=0, children=[]):
        super().__init__()
        self.val = val
        self.adjoint = adjoint
        self.children = children

    def forward(self):
        self.val = np.exp(self.children[0].val)

    def differentiate(self):
        self.children[0].adjoint += self.adjoint * np.exp(self.children[0].val)


class Square(Vertex):
    def __init__(self, val=None, adjoint=0, children=[]):
        super().__init__()
        self.val = val
        self.adjoint = adjoint
        self.children = children

    def forward(self):
        self.val = self.children[0].val**2

    def differentiate(self):
        self.children[0].adjoint += self.adjoint * 2 * self.children[0].val
class Multiply(Vertex):
    def __init__(self, val=None, adjoint=0, children=[]):
        super().__init__()
        self.val = val
        self.adjoint = adjoint
        self.children = children

    def forward(self):
        self.val = self.children[0].val * self.children[1].val

    def differentiate(self):
        self.children[0].adjoint += self.adjoint * self.children[1].val
        self.children[1].adjoint += self.adjoint * self.children[0].val
class Cos(Vertex):
    def __init__(self, val=None, adjoint=0, children=[]):
        super().__init__()
        self.val = val
        self.adjoint = adjoint
        self.children = children

    def forward(self):
        self.val = np.cos(self.children[0].val)

    def differentiate(self):
        self.children[0].adjoint += self.adjoint * -1 * np.sin(self.children[0].val)
class Add(Vertex):
    def __init__(self, val=None, adjoint=0, children=[]):
        super().__init__()
        self.val = val
        self.adjoint = adjoint
        self.children = children

    def forward(self):
        self.val = self.children[0].val + self.children[1].val

    def differentiate(self):
        self.children[0].adjoint += self.adjoint
        self.children[1].adjoint += self.adjoint
class SimpleFunction(Graph):
    def __init__(self):
        super().__init__()
        self.input = [None, None]
        W1Vertex = Data(val=self.input[0])
        W2Vertex = Data(val=self.input[1])
        W3Vertex = TimesTwo(children=[W1Vertex])
        W4Vertex = Exp(children=[W3Vertex])
        W5Vertex = Square(children=[W2Vertex])
        W6Vertex = Multiply(children=[W1Vertex, W5Vertex])
        W7Vertex = Cos(children=[W2Vertex])
        W8Vertex = Add(children=[W4Vertex, W6Vertex])
        W9Vertex = Add(children=[W7Vertex, W8Vertex])
        self.vertex_list = [
            W1Vertex,
            W2Vertex,
            W3Vertex,
            W4Vertex,
            W5Vertex,
            W6Vertex,
            W7Vertex,
            W8Vertex,
            W9Vertex
        ]

    def insert_data_into_graph(self, data):
        self.vertex_list[0].val = data[0]
        self.vertex_list[1].val = data[1]

    def forward(self, x):
        self.insert_data_into_graph(x)
        for v in self.vertex_list:
            v.forward()
        return self.vertex_list[-1].val

    def differentiate(self, x):
        self.forward(x)

        # Reset adjoints so that repeated calls don't accumulate stale values
        for v in self.vertex_list:
            v.adjoint = 0
        self.vertex_list[-1].adjoint = 1

        n_nodes = len(self.vertex_list)
        ii = n_nodes - 1
        while ii >= 0:
            self.vertex_list[ii].differentiate()
            ii -= 1
        return [self.vertex_list[0].adjoint, self.vertex_list[1].adjoint]
if __name__ == "__main__":
    function_graph = SimpleFunction()
    x1, x2 = 1, 2

    output = function_graph.forward([x1, x2])
    numpy_truth = np.exp(2 * x1) + x1 * x2**2 + np.cos(x2)
    assert output == numpy_truth

    dx = function_graph.differentiate([x1, x2])
    dx_true = [2 * np.exp(2 * x1) + x2**2, 2 * x1 * x2 - np.sin(x2)]
    assert np.all(dx == dx_true)