Andy Jones CS PhD student @ Princeton

Automatic differentiation

Automatic differentiation (AD) refers to a family of algorithms that can be used to compute derivatives of functions in a systematized way. AD has been popularized by a slew of machine learning libraries, such as autograd, TensorFlow, and PyTorch. However, the details are often overlooked. This post attempts to explain some of the details and show simple implementations of the two main types of AD: forward-mode and reverse-mode.

In my opinion, the term “automatic differentiation” is a bit of a misnomer. Automatic implies that it’s a general-purpose rule for computing derivatives of arbitrary functions without human specification. In reality, the derivatives of low-level base functions must still be encoded somehow. A better name might be “systematized differentiation” or “programmatic differentiation”. In my mind, related techniques such as numerical differentiation are closer to being automatic, although these don’t compute exact derivatives. Nevertheless, the name has stuck.

We first review the basic calculus used by AD, and then we present the two modes of AD.

Chain rule

The basic ideas of AD rely heavily (almost exclusively) on the chain rule of calculus. As a brief refresher, suppose we have a composition of two functions $y = f(g(x))$. Recall that the chain rule gives us a method for computing the derivative with respect to $x$:

\[\frac{\partial y}{\partial x} = \frac{\partial f}{\partial g} \cdot \frac{\partial g}{\partial x}.\]
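
As a concrete instance (the functions here are chosen purely for illustration), take $g(x) = x^2$ and $f(g) = \sin g$, so that $y = \sin(x^2)$. Each factor in the chain rule is simple to write down on its own, and their product gives the full derivative:

\[\frac{\partial f}{\partial g} = \cos g = \cos(x^2), \qquad \frac{\partial g}{\partial x} = 2x, \qquad \frac{\partial y}{\partial x} = \cos(x^2) \cdot 2x.\]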

We can generalize this to arbitrary compositions of functions,

\[y = (f_1 \circ f_2 \circ \dots \circ f_k)(x),\]

in which case the chain rule yields

\[\frac{\partial y}{\partial x} = \frac{\partial f_1}{\partial f_2} \cdot\frac{\partial f_2}{\partial f_3} \dots \frac{\partial f_{k}}{\partial x}.\]

At its core, one can think of AD as a clever software implementation of the chain rule. By decomposing complex derivatives into simpler partial derivatives, it becomes manageable to write extensible software that can compute derivatives of any code built from those simpler pieces.

Our running example

For this post, let’s consider the following simple function $f : \mathbb{R}^2 \rightarrow \mathbb{R}$ as a running example:

\begin{equation} y = f(x_1, x_2) = e^{2 x_1} + x_1 x_2^2 + \cos x_2. \tag{1} \label{eq:running_example} \end{equation}

We’ll use this function as a testbed throughout this post.

Representing functions as computation graphs

Another key step in turning differentiation into an algorithm is representing functions with a data structure that’s suited to the task. Given that the chain rule is based on the idea of composing multiple functions together, a natural structure to consider is a directed graph. In particular, if we can decompose a function into a set of simpler elementary computations, we can perform each of these computations within a separate node on the graph, and feed outputs and inputs through the edges.

To see this, consider our running example in Equation \eqref{eq:running_example}. Notice that we can break this function down into a set of elementary operations. First, we multiply $x_1$ by $2$, then we exponentiate the result, then we square $x_2$, then we multiply that by $x_1$, and so on. If we denote each intermediate value using $w_1, w_2, \dots$, we can decompose the function as follows. Note that we treat $w_1$ and $w_2$ as the input data itself, $x_1, x_2$. Below, we list each step and the operation that occurs.

\begin{align} y &= e^{2 x_1} + x_1 x_2^2 + \cos x_2 \\ &= e^{2 w_1} + w_1 w_2^2 + \cos w_2 & \text{(Encode data)} \\ &= e^{w_3} + w_1 w_4 + w_5 & \text{(Multiply by 2, square, cosine)} \\ &= w_6 + w_1 w_4 + w_5 & \text{(Exponentiate)} \\ &= w_6 + w_7 + w_5 & \text{(Multiply)} \\ &= w_8 + w_5 & \text{(Add)} \\ &= w_9. & \text{(Add)} \end{align}

In graph form, we can think of each $w_i$ as a vertex representing an elementary operation. This computation graph might look like the following:

A computation graph. Each arrow represents the flow of a numerical value into a parent node.

Both forward- and reverse-mode AD make use of this computation graph to compute derivatives. We now have the basic ingredients needed for AD.
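
To make the decomposition concrete, here is a minimal sketch that evaluates each intermediate value in order and records which earlier values feed into it. The names `children` and `evaluate_trace` are illustrative only and are not part of the implementation later in this post.

```python
import math

# For each intermediate value w_i, the indices of the values it is computed from.
# The inputs w_1 and w_2 have no children.
children = {
    1: [],
    2: [],
    3: [1],     # w3 = 2 * w1
    4: [2],     # w4 = w2 ** 2
    5: [2],     # w5 = cos(w2)
    6: [3],     # w6 = exp(w3)
    7: [1, 4],  # w7 = w1 * w4
    8: [6, 7],  # w8 = w6 + w7
    9: [8, 5],  # w9 = w8 + w5
}


def evaluate_trace(x1, x2):
    """Evaluate the running example one elementary operation at a time."""
    w = {1: x1, 2: x2}
    w[3] = 2 * w[1]
    w[4] = w[2] ** 2
    w[5] = math.cos(w[2])
    w[6] = math.exp(w[3])
    w[7] = w[1] * w[4]
    w[8] = w[6] + w[7]
    w[9] = w[8] + w[5]
    return w


trace = evaluate_trace(1.0, 2.0)
direct = math.exp(2 * 1.0) + 1.0 * 2.0 ** 2 + math.cos(2.0)
print(trace[9], direct)  # the two values should agree
```

Because every node only depends on lower-numbered nodes, evaluating them in index order guarantees that each node's children are ready when it is computed.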

Forward-mode AD

In forward-mode AD, we start at the “input” end of the graph, and work our way forward through the graph (left to right in the above graph). At each intermediate value, we compute a derivative along the way. In terms of a composition of functions, we start with the “innermost” function, and work outward.

In terms of a data structure, we need to store a derivative within each vertex in the graph. For node $w_i$, we use the notation $\dot{w_i}$ to denote its derivative with respect to the input, $\dot{w_i} = \partial w_i / \partial x$. The “dot” notation for derivatives is common in the physics community.

The basic algorithm for forward-mode AD is as follows. Suppose the input has two elements and we’d like to take the derivative with respect to the first one, $x_1$.

  1. Set the data values $w_1 = x_1, w_2 = x_2$.
  2. Set the “seeds” $\dot{w}_1 = 1, \dot{w}_2 = 0$.
  3. Until we’ve reached the output node:
    1. Compute the derivative $\dot{w}_i$ at each node $w_i$, traversing from children to parents.

Note that as we traverse forward in the graph, we will need the derivatives of the children of each node so that we can properly use the chain rule on the fly.

To see this in practice, let’s come back to our running example. To compute the derivative with respect to $x_1$, we set $\dot{w_1} = 1, \dot{w_2} = 0$. We can then traverse forward in the graph as follows:

\begin{align} \dot{w_1} &= 1 \\ \dot{w_2} &= 0 \\ \dot{w_3} &= \frac{\partial w_3}{\partial w_1} \cdot \dot{w_1} \\ \dot{w_4} &= \frac{\partial w_4}{\partial w_2} \cdot \dot{w_2} \\ \dot{w_5} &= \frac{\partial w_5}{\partial w_2} \cdot \dot{w_2} \\ \dot{w_6} &= \frac{\partial w_6}{\partial w_3} \cdot \dot{w_3} \\ \dot{w_7} &= \dot{w_1} w_4 + w_1 \dot{w_4} \\ \dot{w_8} &= \dot{w_6} + \dot{w_7} \\ \dot{w_9} &= \dot{w_8} + \dot{w_5}. \end{align}

Forward-mode automatic differentiation. The derivative $\dot{w}_i$ is stored at each node as we traverse forward from input to output.

Above, we only took the derivative with respect to $x_1$. To compute the derivative with respect to $x_2$, we would repeat the same process, but this time we would set the seeds as $\dot{w_1} = 0, \dot{w_2} = 1$. Notice that forward-mode AD needs one pass per input dimension, so its computational cost scales with the size of the input.
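
The traversal can also be written out as straight-line code before introducing any graph classes. The sketch below is a hand-unrolled version for the running example only, using the $w_3, \dots, w_9$ numbering from the decomposition of Equation \eqref{eq:running_example}; the function name `forward_mode` and the `seed` argument are illustrative choices.

```python
import math


def forward_mode(x1, x2, seed):
    """Return (y, dy/dx_seed) for the running example.

    seed is (dot_w1, dot_w2): use (1, 0) for d/dx1 and (0, 1) for d/dx2.
    """
    # Each line carries a value and its derivative with respect to the seeded input.
    w1, dw1 = x1, seed[0]
    w2, dw2 = x2, seed[1]
    w3, dw3 = 2 * w1, 2 * dw1
    w4, dw4 = w2 ** 2, 2 * w2 * dw2
    w5, dw5 = math.cos(w2), -math.sin(w2) * dw2
    w6, dw6 = math.exp(w3), math.exp(w3) * dw3
    w7, dw7 = w1 * w4, dw1 * w4 + w1 * dw4  # product rule
    w8, dw8 = w6 + w7, dw6 + dw7
    w9, dw9 = w8 + w5, dw8 + dw5
    return w9, dw9


x1, x2 = 1.0, 2.0
y, dy_dx1 = forward_mode(x1, x2, seed=(1.0, 0.0))  # first pass: seed x1
_, dy_dx2 = forward_mode(x1, x2, seed=(0.0, 1.0))  # second pass: seed x2
print(y, dy_dx1, dy_dx2)
```

Two calls are needed to get both partial derivatives, which is the per-input cost described above.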

Python example

To implement forward-mode AD, we need a simple graph data structure. Here, I’ll show a simple (albeit not very extensible) way to do this. First, consider a Vertex class for a computation graph.

class Vertex():

	def __init__(self, val=None, val_dot=None, children=[]):
		self.val = val
		self.val_dot = val_dot
		self.children = children

	def forward(self):
		# Perform elementary operation (overridden by subclasses)
		pass

	def differentiate(self):
		# Compute "dot" of this node (overridden by subclasses)
		pass

Note that we need both a forward function and a differentiate function. The edges in the graph are determined by the children of each node. A Graph class will simply be made up of a collection of Vertex objects.

class Graph():

	def __init__(self):
		self.input = None

	def forward(self, x):
		pass

The forward function in a Graph will call forward iteratively on each Vertex.

We can then implement a class for each elementary operation that we need, and these classes will inherit from the generic Vertex class. For example, for multiplication by a scalar ($2$ in this case), we can consider the following class.

class TimesTwo(Vertex):

	def __init__(self, val=None, val_dot=None, children=[]):
		super().__init__()
		self.val = val
		self.val_dot = val_dot
		self.children = children

	def forward(self):
		self.val = 2 * self.children[0].val

	def differentiate(self):
		self.val_dot = 2 * self.children[0].val_dot

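See the Appendix for the full code of each elementary operation in our running example. By chaining these elementary operations together, we can create an entire graph. Implementing the forward and differentiate functions in the entire graph is then just a matter of looping through the vertices, and calling the appropriate function. Note that the vertex numbering in the code below differs slightly from the $w_i$ numbering in the decomposition above (for example, W4Vertex is the exponential); any ordering in which each vertex appears after its children works.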

class SimpleFunction(Graph):

	def __init__(self):
		super().__init__()

		self.input = [None, None]
		W1Vertex = Data(val=self.input[0])
		W2Vertex = Data(val=self.input[1])
		W3Vertex = TimesTwo(children=[W1Vertex])
		W4Vertex = Exp(children=[W3Vertex])
		W5Vertex = Square(children=[W2Vertex])
		W6Vertex = Multiply(children=[W1Vertex, W5Vertex])
		W7Vertex = Cos(children=[W2Vertex])
		W8Vertex = Add(children=[W4Vertex, W6Vertex])
		W9Vertex = Add(children=[W7Vertex, W8Vertex])

		self.vertex_list = [
			W1Vertex,
			W2Vertex,
			W3Vertex,
			W4Vertex,
			W5Vertex,
			W6Vertex,
			W7Vertex,
			W8Vertex,
			W9Vertex
		]
	
	def insert_data_into_graph(self, data):
		self.vertex_list[0].val = data[0]
		self.vertex_list[1].val = data[1]

	def forward(self, x):

		self.insert_data_into_graph(x)
		
		for v in self.vertex_list:
			v.forward()

		return self.vertex_list[-1].val

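	def differentiate(self, x):

		# Populate val at every vertex first
		self.forward(x)

		# Seeds for the derivative with respect to x1
		self.vertex_list[0].val_dot = 1
		self.vertex_list[1].val_dot = 0

		# Propagate val_dot from children to parents
		for v in self.vertex_list[2:]:
			v.differentiate()

		return self.vertex_list[-1].val_dot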
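
A quick way to sanity-check derivatives like the one returned by differentiate is to compare them against central finite differences, the numerical differentiation technique mentioned at the start of this post. The sketch below does not use the classes above; `f`, `analytic_gradient`, and `numerical_gradient` are illustrative names, and the step size `h` is an arbitrary choice.

```python
import math


def f(x1, x2):
    """The running example."""
    return math.exp(2 * x1) + x1 * x2 ** 2 + math.cos(x2)


def analytic_gradient(x1, x2):
    """Hand-derived partial derivatives of the running example."""
    return [2 * math.exp(2 * x1) + x2 ** 2, 2 * x1 * x2 - math.sin(x2)]


def numerical_gradient(func, x, h=1e-5):
    """Approximate each partial derivative with a central difference."""
    grad = []
    for i in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[i] += h
        x_minus[i] -= h
        grad.append((func(*x_plus) - func(*x_minus)) / (2 * h))
    return grad


exact = analytic_gradient(1.0, 2.0)
approx = numerical_gradient(f, [1.0, 2.0])
errors = [abs(a - b) for a, b in zip(exact, approx)]
print(exact, approx, errors)
```

The approximation is close but not exact, which is why comparisons of this kind should use a tolerance.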

Reverse-mode AD

In reverse-mode AD, we traverse the computation graph in the other direction: from output to input (from the outermost function to the innermost). In this case, rather than computing the derivative of each node with respect to the input, we compute the derivative of the output with respect to each node: $\partial y/\partial w_i$.

Algorithmically, we store a quantity known as the “adjoint” for each node, where we typically use a bar notation,

\[\bar{w}_i = \frac{\partial y}{\partial w_i}.\]

As we traverse backward through the graph, we can compute each individual adjoint using the previous adjoints computed for each node’s parents:

\[\bar{w}_i = \sum\limits_{p \in \text{parents}} \bar{w}_p \frac{\partial w_p}{\partial w_i}.\]

Reverse-mode automatic differentiation. The adjoint $\bar{w}_i$ is stored at each node as we traverse backward from output to input.

In contrast to forward-mode AD, reverse-mode AD is able to compute the derivatives with respect to both $x_1$ and $x_2$ with just one pass. In general, the computational cost of reverse-mode AD will scale with the size of the output. In most settings, it’s more likely that the input is high-dimensional, while the output is one- or low-dimensional, so reverse-mode AD is often computationally preferable. Reverse-mode AD serves as the basis for the backpropagation algorithm, a popular approach to training neural networks.
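
As with forward mode, the reverse sweep can be unrolled by hand for the running example. The sketch below uses the $w_3, \dots, w_9$ numbering from the decomposition of Equation \eqref{eq:running_example}; the function name `reverse_mode` and the variable names `a1, ..., a9` (standing in for the adjoints $\bar{w}_i$) are illustrative.

```python
import math


def reverse_mode(x1, x2):
    """Return (y, [dy/dx1, dy/dx2]) for the running example."""
    # Forward sweep: compute and keep every intermediate value.
    w1, w2 = x1, x2
    w3 = 2 * w1
    w4 = w2 ** 2
    w5 = math.cos(w2)
    w6 = math.exp(w3)
    w7 = w1 * w4
    w8 = w6 + w7
    w9 = w8 + w5

    # Reverse sweep: each adjoint sums contributions from the node's parents.
    a9 = 1.0                    # dy/dy
    a8 = a9                     # w9 = w8 + w5
    a5 = a9
    a6 = a8                     # w8 = w6 + w7
    a7 = a8
    a3 = a6 * math.exp(w3)      # w6 = exp(w3)
    a4 = a7 * w1                # w7 = w1 * w4
    a1 = a7 * w4 + a3 * 2       # w1 feeds both w7 and w3
    a2 = a4 * 2 * w2 + a5 * -math.sin(w2)  # w2 feeds both w4 and w5
    return w9, [a1, a2]


y, grad = reverse_mode(1.0, 2.0)
print(y, grad)
```

A single backward sweep yields both partial derivatives, and the inputs $w_1$ and $w_2$ show why adjoints have to be summed: each of them has two parents.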

In this case, the Python implementation looks very similar. However, instead of computing the derivative for each node, we compute and store the adjoint for each node’s children at each step. For example, for the multiply operation, we have the following class:

class Multiply(Vertex):

	def __init__(self, val=None, adjoint=0, children=[]):
		super().__init__()
		self.val = val
		self.adjoint = adjoint
		self.children = children

	def forward(self):
		self.val = self.children[0].val * self.children[1].val

	def differentiate(self):
		self.children[0].adjoint += self.adjoint * self.children[1].val
		self.children[1].adjoint += self.adjoint * self.children[0].val

Note that we initialize the adjoint to be $0$ for each node, so that we can increment it properly with each of its parents. The full class for the graph then looks like the following.

class SimpleFunction(Graph):

	def __init__(self):
		super().__init__()

		self.input = [None, None]
		W1Vertex = Data(val=self.input[0])
		W2Vertex = Data(val=self.input[1])
		W3Vertex = TimesTwo(children=[W1Vertex])
		W4Vertex = Exp(children=[W3Vertex])
		W5Vertex = Square(children=[W2Vertex])
		W6Vertex = Multiply(children=[W1Vertex, W5Vertex])
		W7Vertex = Cos(children=[W2Vertex])
		W8Vertex = Add(children=[W4Vertex, W6Vertex])
		W9Vertex = Add(children=[W7Vertex, W8Vertex])

		self.vertex_list = [
			W1Vertex,
			W2Vertex,
			W3Vertex,
			W4Vertex,
			W5Vertex,
			W6Vertex,
			W7Vertex,
			W8Vertex,
			W9Vertex
		]
	
	def insert_data_into_graph(self, data):
		self.vertex_list[0].val = data[0]
		self.vertex_list[1].val = data[1]

	def forward(self, x):

		self.insert_data_into_graph(x)
		
		for v in self.vertex_list:
			v.forward()

		return self.vertex_list[-1].val

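	def differentiate(self, x):

		self.forward(x)

		# Reset adjoints so that repeated calls don't accumulate stale values
		for v in self.vertex_list:
			v.adjoint = 0

		# Seed the output node: dy/dy = 1
		self.vertex_list[-1].adjoint = 1

		# Traverse from output to input, pushing adjoints to each node's children
		for v in reversed(self.vertex_list):
			v.differentiate()

		return [self.vertex_list[0].adjoint, self.vertex_list[1].adjoint]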

Caveats

It’s worth noting that the implementation I’ve shown above is purely for pedagogical purposes, and it’s nowhere near as general or useful as the implementations in packages like autograd or PyTorch. While my code above forms the graph manually, these packages have much more sophisticated machinery for “tracing” through a function and automatically building a computation graph, as well as much more efficient linear algebra routines. Furthermore, they have much more general classes for computing derivatives.
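
To give a flavor of what “tracing” can mean, here is a minimal sketch that builds the graph automatically through operator overloading and then runs a reverse sweep over it. This is only an illustration under simplifying assumptions (scalars only, a handful of operations) and is not how autograd or PyTorch are actually implemented; the names `Var`, `exp`, `cos`, and `backward` are hypothetical.

```python
import math


class Var:
    """A scalar that records which values it was computed from (illustrative only)."""

    def __init__(self, val, children=()):
        self.val = val
        # Tuple of (child Var, partial derivative of this node w.r.t. that child)
        self.children = children
        self.adjoint = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.val + other.val, ((self, 1.0), (other, 1.0)))

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.val * other.val, ((self, other.val), (other, self.val)))

    __rmul__ = __mul__


def exp(v):
    return Var(math.exp(v.val), ((v, math.exp(v.val)),))


def cos(v):
    return Var(math.cos(v.val), ((v, -math.sin(v.val)),))


def backward(output):
    """Accumulate adjoints for every node reachable from output."""
    order, seen = [], set()

    def visit(node):
        # Post-order depth-first search: children are appended before the node itself
        if id(node) in seen:
            return
        seen.add(id(node))
        for child, _ in node.children:
            visit(child)
        order.append(node)

    visit(output)
    output.adjoint = 1.0
    # Visit each node only after all of its parents have been processed
    for node in reversed(order):
        for child, local_grad in node.children:
            child.adjoint += node.adjoint * local_grad


# The running example, written as ordinary arithmetic; the graph is built as it runs.
x1, x2 = Var(1.0), Var(2.0)
y = exp(2 * x1) + x1 * x2 * x2 + cos(x2)
backward(y)
print(y.val, x1.adjoint, x2.adjoint)
```

The design choice here is that each operation stores its local partial derivatives at the moment it runs, so the backward pass never needs to know which operation produced a node.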

Appendix

Full code for forward-mode AD:

import numpy as np

class Vertex():

	def __init__(self, val=None, val_dot=None, children=[]):
		self.val = val
		self.val_dot = val_dot
		self.children = children

	def forward(self):
		pass

	def differentiate(self):
		pass

class Graph():

	def __init__(self):
		self.input = None

	def forward(self, x):
		pass

class Data(Vertex):

	def __init__(self, val=None, val_dot=None, children=[]):
		super().__init__()
		self.val = val
		self.val_dot = val_dot
		self.children = children

	def forward(self):
		pass

	def differentiate(self):
		# The seed val_dot is set directly by the graph, so there is nothing to compute
		pass


class TimesTwo(Vertex):

	def __init__(self, val=None, val_dot=None, children=[]):
		super().__init__()
		self.val = val
		self.val_dot = val_dot
		self.children = children

	def forward(self):
		self.val = 2 * self.children[0].val

	def differentiate(self):
		self.val_dot = 2 * self.children[0].val_dot

class Exp(Vertex):

	def __init__(self, val=None, val_dot=None, children=[]):
		super().__init__()
		self.val = val
		self.val_dot = val_dot
		self.children = children

	def forward(self):
		self.val = np.exp(self.children[0].val)

	def differentiate(self):
		self.val_dot = self.children[0].val_dot * np.exp(self.children[0].val)

class Square(Vertex):

	def __init__(self, val=None, val_dot=None, children=[]):
		super().__init__()
		self.val = val
		self.val_dot = val_dot
		self.children = children

	def forward(self):
		self.val = self.children[0].val**2

	def differentiate(self):
		self.val_dot = self.children[0].val_dot * 2 * self.children[0].val

class Multiply(Vertex):

	def __init__(self, val=None, val_dot=None, children=[]):
		super().__init__()
		self.val = val
		self.val_dot = val_dot
		self.children = children

	def forward(self):
		self.val = self.children[0].val * self.children[1].val

	def differentiate(self):
		self.val_dot = self.children[0].val_dot * self.children[1].val + self.children[0].val * self.children[1].val_dot

class Cos(Vertex):

	def __init__(self, val=None, val_dot=None, children=[]):
		super().__init__()
		self.val = val
		self.val_dot = val_dot
		self.children = children

	def forward(self):
		self.val = np.cos(self.children[0].val)

	def differentiate(self):
		self.val_dot = self.children[0].val_dot * -1 * np.sin(self.children[0].val)

class Add(Vertex):

	def __init__(self, val=None, val_dot=None, children=[]):
		super().__init__()
		self.val = val
		self.val_dot = val_dot
		self.children = children

	def forward(self):
		self.val = self.children[0].val + self.children[1].val

	def differentiate(self):
		self.val_dot = self.children[0].val_dot + self.children[1].val_dot


class SimpleFunction(Graph):

	def __init__(self):
		super().__init__()

		self.input = [None, None]
		W1Vertex = Data(val=self.input[0])
		W2Vertex = Data(val=self.input[1])
		W3Vertex = TimesTwo(children=[W1Vertex])
		W4Vertex = Exp(children=[W3Vertex])
		W5Vertex = Square(children=[W2Vertex])
		W6Vertex = Multiply(children=[W1Vertex, W5Vertex])
		W7Vertex = Cos(children=[W2Vertex])
		W8Vertex = Add(children=[W4Vertex, W6Vertex])
		W9Vertex = Add(children=[W7Vertex, W8Vertex])

		self.vertex_list = [
			W1Vertex,
			W2Vertex,
			W3Vertex,
			W4Vertex,
			W5Vertex,
			W6Vertex,
			W7Vertex,
			W8Vertex,
			W9Vertex
		]
	
	def insert_data_into_graph(self, data):
		self.vertex_list[0].val = data[0]
		self.vertex_list[1].val = data[1]

	def forward(self, x):

		self.insert_data_into_graph(x)
		
		for v in self.vertex_list:
			v.forward()

		return self.vertex_list[-1].val

	def differentiate(self, x):

		self.forward(x)

		# Derivative wrt x1
		self.vertex_list[0].val_dot = 1
		self.vertex_list[1].val_dot = 0
		for v in self.vertex_list[2:]:
			v.differentiate()

		dx1 = self.vertex_list[-1].val_dot

		# Derivative wrt x2
		self.vertex_list[0].val_dot = 0
		self.vertex_list[1].val_dot = 1
		for v in self.vertex_list[2:]:
			v.differentiate()

		dx2 = self.vertex_list[-1].val_dot

		return [dx1, dx2]


if __name__ == "__main__":
	function_graph = SimpleFunction()
	x1, x2 = 1, 2
	output = function_graph.forward([x1, x2])

	numpy_truth = np.exp(2 * x1) + x1 * x2**2 + np.cos(x2)
	# Compare with a tolerance rather than exact floating-point equality
	assert np.isclose(output, numpy_truth)

	dx = function_graph.differentiate([x1, x2])
	dx_true = [2 * np.exp(2 * x1) + x2**2, 2 * x1 * x2 - np.sin(x2)]

	assert np.allclose(dx, dx_true)


Full code for reverse-mode AD:

import numpy as np

class Vertex():

	def __init__(self, val=None, adjoint=0, children=[]):
		self.val = val
		self.adjoint = adjoint
		self.children = children

	def forward(self):
		pass

	def differentiate(self):
		pass

class Graph():

	def __init__(self):
		self.input = None

	def forward(self, x):
		pass

class Data(Vertex):

	def __init__(self, val=None, adjoint=0, children=[]):
		super().__init__()
		self.val = val
		self.adjoint = adjoint
		self.children = children

	def forward(self):
		pass

	def differentiate(self):
		pass


class TimesTwo(Vertex):

	def __init__(self, val=None, adjoint=0, children=[]):
		super().__init__()
		self.val = val
		self.adjoint = adjoint
		self.children = children

	def forward(self):
		self.val = 2 * self.children[0].val

	def differentiate(self):
		self.children[0].adjoint += self.adjoint * 2

class Exp(Vertex):

	def __init__(self, val=None, adjoint=0, children=[]):
		super().__init__()
		self.val = val
		self.adjoint = adjoint
		self.children = children

	def forward(self):
		self.val = np.exp(self.children[0].val)

	def differentiate(self):
		self.children[0].adjoint += self.adjoint * np.exp(self.children[0].val)

class Square(Vertex):

	def __init__(self, val=None, adjoint=0, children=[]):
		super().__init__()
		self.val = val
		self.adjoint = adjoint
		self.children = children

	def forward(self):
		self.val = self.children[0].val**2

	def differentiate(self):
		self.children[0].adjoint += self.adjoint * 2 * self.children[0].val

class Multiply(Vertex):

	def __init__(self, val=None, adjoint=0, children=[]):
		super().__init__()
		self.val = val
		self.adjoint = adjoint
		self.children = children

	def forward(self):
		self.val = self.children[0].val * self.children[1].val

	def differentiate(self):
		self.children[0].adjoint += self.adjoint * self.children[1].val
		self.children[1].adjoint += self.adjoint * self.children[0].val

class Cos(Vertex):

	def __init__(self, val=None, adjoint=0, children=[]):
		super().__init__()
		self.val = val
		self.adjoint = adjoint
		self.children = children

	def forward(self):
		self.val = np.cos(self.children[0].val)

	def differentiate(self):
		self.children[0].adjoint += self.adjoint * -1 * np.sin(self.children[0].val)

class Add(Vertex):

	def __init__(self, val=None, adjoint=0, children=[]):
		super().__init__()
		self.val = val
		self.adjoint = adjoint
		self.children = children

	def forward(self):
		self.val = self.children[0].val + self.children[1].val

	def differentiate(self):
		self.children[0].adjoint += self.adjoint
		self.children[1].adjoint += self.adjoint


class SimpleFunction(Graph):

	def __init__(self):
		super().__init__()

		self.input = [None, None]
		W1Vertex = Data(val=self.input[0])
		W2Vertex = Data(val=self.input[1])
		W3Vertex = TimesTwo(children=[W1Vertex])
		W4Vertex = Exp(children=[W3Vertex])
		W5Vertex = Square(children=[W2Vertex])
		W6Vertex = Multiply(children=[W1Vertex, W5Vertex])
		W7Vertex = Cos(children=[W2Vertex])
		W8Vertex = Add(children=[W4Vertex, W6Vertex])
		W9Vertex = Add(children=[W7Vertex, W8Vertex])

		self.vertex_list = [
			W1Vertex,
			W2Vertex,
			W3Vertex,
			W4Vertex,
			W5Vertex,
			W6Vertex,
			W7Vertex,
			W8Vertex,
			W9Vertex
		]
	
	def insert_data_into_graph(self, data):
		self.vertex_list[0].val = data[0]
		self.vertex_list[1].val = data[1]

	def forward(self, x):

		self.insert_data_into_graph(x)
		
		for v in self.vertex_list:
			v.forward()

		return self.vertex_list[-1].val

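	def differentiate(self, x):

		self.forward(x)

		# Reset adjoints so that repeated calls don't accumulate stale values
		for v in self.vertex_list:
			v.adjoint = 0

		# Seed the output node: dy/dy = 1
		self.vertex_list[-1].adjoint = 1

		# Traverse from output to input, pushing adjoints to each node's children
		for v in reversed(self.vertex_list):
			v.differentiate()

		return [self.vertex_list[0].adjoint, self.vertex_list[1].adjoint]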


if __name__ == "__main__":
	function_graph = SimpleFunction()
	x1, x2 = 1, 2
	output = function_graph.forward([x1, x2])

	numpy_truth = np.exp(2 * x1) + x1 * x2**2 + np.cos(x2)
	# Compare with a tolerance rather than exact floating-point equality
	assert np.isclose(output, numpy_truth)

	dx = function_graph.differentiate([x1, x2])
	dx_true = [2 * np.exp(2 * x1) + x2**2, 2 * x1 * x2 - np.sin(x2)]

	assert np.allclose(dx, dx_true)