Symbolic, Numerical, and Automatic Differentiation
A few general types of differentiation
Numerical Differentiation (i.e., finite differences)
Symbolic Differentiation (i.e., chain rule and simplify subexpressions by hand)
Automatic Differentiation (i.e., execute chain rule on computer)
Sparse Differentiation (i.e., use one of the above to calculate directional derivatives, potentially filling in sparse Jacobians with fewer passes)
\[ \frac{\partial f(x)}{\partial x_i} \approx \frac{f(x + \epsilon e_i) - f(x)}{\epsilon} \]
\[ f'(x) \approx \frac{f(x-2\epsilon) - 8f(x-\epsilon) + 8f(x+\epsilon) - f(x+2\epsilon)}{12\epsilon} \]
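As a quick illustration (a sketch; the helper names fwd_diff and central_diff4 are ours, not from any library), both stencils applied to \(\sin\) and compared against the exact derivative:
import math

def fwd_diff(f, x, eps=1e-6):
    # One-sided difference: O(eps) truncation error, one extra evaluation of f
    return (f(x + eps) - f(x)) / eps

def central_diff4(f, x, eps=1e-3):
    # Five-point central stencil: O(eps^4) truncation error, four evaluations
    return (f(x - 2 * eps) - 8 * f(x - eps) + 8 * f(x + eps) - f(x + 2 * eps)) / (12 * eps)

# d/dx sin(x) = cos(x); compare at x = 1.0
print(fwd_diff(math.sin, 1.0), central_diff4(math.sin, 1.0), math.cos(1.0))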
\[ f'(x) = g'(x) f_1(g(x), h(g(x))) + g'(x) h'(g(x)) f_2(g(x), h(g(x))) \]
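The same structure can be differentiated symbolically; a minimal sketch assuming sympy is available (sympy is not otherwise used in these notes), with \(g(x)=\sin x\), \(h(y)=e^y\), and \(f(u,w)=uw\):
import sympy as sp

x = sp.symbols('x')
g = sp.sin(x)                        # g(x)
h = sp.exp(g)                        # h(g(x))
F = g * h                            # f(g(x), h(g(x))) with f(u, w) = u * w
print(sp.simplify(sp.diff(F, x)))    # chain rule applied and simplified for us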
Automatic differentiation / differentiable programming works on “computer programs”, i.e., computational graphs, which are just compositions of functions
Finally: many frameworks will compile the resulting sequence of operations to be efficient on a GPU since this is so central to deep learning performance
Use the standard basis vectors \(e_1, e_2\) and calculate \(\mathcal{A}(e_1), \mathcal{A}(e_2)\)
Use the standard basis vectors \(e_1, e_2, e_3\) (now of \(\mathbb{R}^3\)) and calculate \(\mathcal{A}^{\top}(e_1), \mathcal{A}^{\top}(e_2), \mathcal{A}^{\top}(e_3)\)
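A small sketch of recovering the matrix of a linear operator this way; the particular \(\mathcal{A}\) (and its adjoint) below are made up for illustration:
import jax.numpy as jnp

# A matrix-free linear operator A : R^2 -> R^3 and its adjoint A^T : R^3 -> R^2
def A(v):
    return jnp.array([v[0] + 2.0 * v[1], 3.0 * v[0], v[0] - v[1]])

def A_T(u):
    return jnp.array([u[0] + 3.0 * u[1] + u[2], 2.0 * u[0] - u[2]])

# Columns of the matrix are A(e_1), A(e_2); columns of the transpose are A^T(e_j)
A_mat = jnp.stack([A(e) for e in jnp.eye(2)], axis=1)        # 3 x 2
A_T_mat = jnp.stack([A_T(e) for e in jnp.eye(3)], axis=1)    # 2 x 3
print(A_mat)
print(jnp.allclose(A_T_mat, A_mat.T))                        # True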
Denote the operator, linearized around \(x\) and applied to \(v \in \mathbb{R}^N\), as
\[ (x, v) \mapsto \partial f(x)[v] \in \mathbb{R}^M \]
JAX (and others) will take an \(f\) and an \(x\) and compile a new function from \(\mathbb{R}^N\) to \(\mathbb{R}^M\) that calculates \(\partial f(x)[v]\)
\[ \partial f(x)^{\top} : \mathbb{R}^M \to \mathbb{R}^N \]
Let \(f : \mathbb{R}^2 \to \mathbb{R}^2\) be defined as
\[ f(x) \equiv \begin{bmatrix} x_1^2 + x_2^2 \\ x_1 x_2 \end{bmatrix} \]
Then
\[ \nabla f(x) \equiv \begin{bmatrix} 2 x_1 & 2 x_2 \\ x_2 & x_1 \end{bmatrix} \]
Let \(v = u = \begin{bmatrix} 1 & 0 \end{bmatrix}^{\top}\), i.e., \(e_1\) in the standard basis. Then
\[ \partial f(x)[v] = \nabla f(x) \cdot \begin{bmatrix} 1 \\ 0 \end{bmatrix} = \begin{bmatrix} 2 x_1 \\ x_2 \end{bmatrix} \]
\[ \partial f(x)^{\top}[u] = \begin{bmatrix} 1 & 0 \end{bmatrix} \cdot \nabla f(x) = \begin{bmatrix} 2 x_1 & 2 x_2 \end{bmatrix} \]
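The same example with JAX's jvp and vjp, evaluated at an arbitrary point (here \(x = (3, 4)\)):
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0]**2 + x[1]**2, x[0] * x[1]])

x = jnp.array([3.0, 4.0])
v = jnp.array([1.0, 0.0])      # e_1

# Push-forward: Jacobian times v, i.e. [2 x_1, x_2]
y, jvp_out = jax.jvp(f, (x,), (v,))
print(jvp_out)                 # [6., 4.]

# Pull-back: v^T times Jacobian, i.e. [2 x_1, 2 x_2]
y, vjp_fn = jax.vjp(f, x)
print(vjp_fn(v)[0])            # [6., 8.]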
\[ \partial f(x) = \partial c(b(a(x))) \circ \partial b(a(x)) \circ \partial a(x) \]
\[ \partial f(x)[v] = \partial c(b(a(x))) \left[ \partial b(a(x))[\partial a(x)[v]] \right] \]
The calculation proceeds from the inside out, recursively finding the linearization points:
This conveniently follows the order of the “primal” calculation. There are many ways to implement it (e.g., operator overloading, dual numbers)
Can calculate the “primal” and the “push-forward” at the same time
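As a toy illustration of the overloading/dual-number approach (not how any production framework is implemented), a minimal Dual class that carries the primal value and the tangent together, with a sin rule added by hand:
import math

class Dual:
    # Carries (value, derivative) through the computation
    def __init__(self, val, dot):
        self.val, self.dot = val, dot
    def __add__(self, other):
        return Dual(self.val + other.val, self.dot + other.dot)
    def __mul__(self, other):
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

def sin(d):
    return Dual(math.sin(d.val), math.cos(d.val) * d.dot)

# Forward-mode derivative of x * sin(x) at x = 2: sin(2) + 2 cos(2)
x = Dual(2.0, 1.0)                 # seed tangent v = 1
y = x * sin(x)
print(y.val, y.dot)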
\[ \partial f(x) = \partial c(b(a(x))) \circ \partial b(a(x)) \circ \partial a(x) \]
\[ \partial f(x)^{\top} = \partial a(x)^{\top} \circ \partial b(a(x))^{\top} \circ \partial c(b(a(x)))^{\top} \]
\[ \partial f(x)^{\top}[u] = \partial a(x)^{\top} \left[ \partial b(a(x))^{\top} \left[ \partial c(b(a(x)))^{\top}[u] \right] \right] \]
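A sketch of the same outside-in pull-back using jax.vjp, with placeholder functions a, b, c chosen only to make the shapes concrete:
import jax
import jax.numpy as jnp

a = lambda x: jnp.sin(x)
b = lambda y: y ** 2
c = lambda z: jnp.sum(z)

def f(x):
    return c(b(a(x)))

x = jnp.array([0.5, 1.5])
u = jnp.array(1.0)                    # cotangent for the scalar output

# Build the VJPs while running the primal inside-out ...
ax, vjp_a = jax.vjp(a, x)
bx, vjp_b = jax.vjp(b, ax)
cx, vjp_c = jax.vjp(c, bx)

# ... then pull the cotangent back outside-in
grad_manual = vjp_a(vjp_b(vjp_c(u)[0])[0])[0]
print(grad_manual)
print(jax.grad(f)(x))                 # same result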
PyTorch: mark input tensors with requires_grad=True, run the forward pass, then call .backward() on the output to run reverse-mode AD and accumulate gradients into .grad
import torch

x = torch.tensor(2.0, requires_grad=True)
# Trace computations for the "forward" pass
y = torch.tanh(x)
# Do the "backward" pass for Reverse-mode AD
y.backward()
print(x.grad)
def f(x, y):
return x**3 + 2 * y[0]**2 - 3 * y[1] + 1
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor([2.0, 3.0],
requires_grad=True)
z = f(x, y)
z.backward()
print(x.grad, y.grad)
tensor(0.0707)
tensor(3.) tensor([ 8., -3.])
JAX provides grad (reverse-mode differentiation for \(\mathbb{R}^N \to \mathbb{R}\) functions) as well as lower-level functions to directly use jvp, vjp, and Hessian-vector products. Use jax.config.update('jax_enable_x64', True) for 64-bit precision (the default is 32-bit). grad is the high-level reverse-mode AD function:
import jax
import jax.numpy as jnp
from jax import grad, jit, jvp, vjp, random

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
grad_tanh_jit = jit(grad_tanh)
print(grad_tanh_jit(2.0))
def f(x):
return x**3 + 2 * x**2 - 3 * x + 1
print(grad(f)(1.0))
@jit
def f2(x):
return x**3 + 2 * x**2 - 3 * x + 1
grad(f2)(1.0)
0.070650816
0.070650816
4.0
Array(4., dtype=float32, weak_type=True)
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
key = random.PRNGKey(0)  # any fixed seed; 0 is an arbitrary choice
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward
y, u = jvp(f, (W,), (v,))
print((y, u))
(Array([0.13262254, 0.952067 , 0.6249393 , 0.9980987 ], dtype=float32), Array([ 0.01893247, 0.06917656, -0.00554809, 0.00501772], dtype=float32))
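The pull-back counterpart of the example above (a sketch; u is an arbitrary cotangent in the output space \(\mathbb{R}^4\)):
# Pull back a cotangent u from the output space R^4 to the parameter space R^3
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
print(vjp_fun(u))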
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
params = {'W': W, 'b': b}
print(grad(loss2)(params))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}
For a JVP or VJP, we first need to calculate the primal value \(f(x)\)
It is often madness to descend recursively into the primal calculations of every primitive
\[ \partial f(x)[v] = -\sin(x) \cdot v \quad \text{(e.g., for } f(x) = \cos(x)\text{)} \]
AD systems all have a library of these rules, and typically a way to create new ones for “custom” rules for complicated functions
@jax.custom_jvp
def f(x, y):
return jnp.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
return primal_out, tangent_out
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.)) # perturb x, not y
print(y, y_dot)
2.7278922
2.7278922 -1.2484405
\[ \begin{aligned} I &= C A\\ 0 &= \partial C A + C \partial A \\ 0 &= \partial C A C + C (\partial A) C \\ 0 &= \partial C + C (\partial A) C \\ \partial C &= -C (\partial A) C \\ \end{aligned} \]
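A quick numerical check of \(\partial C = -C\,(\partial A)\,C\) with \(C \equiv A^{-1}\), using an arbitrary matrix and direction:
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0],
               [0.5, 3.0]])
V = jnp.array([[0.1, -0.2],
               [0.3, 0.4]])          # an arbitrary direction dA

C = jnp.linalg.inv(A)

# JVP of the inverse in direction V ...
_, dC = jax.jvp(jnp.linalg.inv, (A,), (V,))

# ... matches -C V C
print(jnp.allclose(dC, -C @ V @ C, atol=1e-5))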
Solve the primal problem \(z^*(a) = f(a, z^*(a))\) for \(z^*(a)\) using Anderson acceleration, Newton, etc., holding \(a\) fixed. Use the implicit function theorem at \(z^* \equiv z^*(a_0)\) \[ \frac{\partial z^*(a)}{\partial a} = \left[ I - \frac{\partial f(a, z^*)}{\partial z} \right]^{-1} \frac{\partial f(a, z^*)}{\partial a}. \]
For JVP: \((a, v) \mapsto \frac{\partial z^*(a)}{\partial a}v\)
\[ \frac{\partial z^*(a)}{\partial a}\cdot v = \left[ I - \frac{\partial f(a,z^*)}{\partial z} \right]^{-1} \frac{\partial f(a, z^*)}{\partial a}\cdot v \]
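A hand-rolled scalar sketch (not library code): the Babylonian iteration \(z = f(a, z) = \tfrac{1}{2}(z + a/z)\) has fixed point \(z^*(a) = \sqrt{a}\), so the implicit JVP should reproduce \(v / (2\sqrt{a})\):
import jax
import jax.numpy as jnp

def f(a, z):
    return 0.5 * (z + a / z)          # fixed point z* = sqrt(a)

def solve_fixed_point(a, z0=1.0, iters=50):
    # Solve the primal problem by naive fixed-point iteration
    z = z0
    for _ in range(iters):
        z = f(a, z)
    return z

def implicit_jvp(a, v):
    z_star = solve_fixed_point(a)
    df_dz = jax.grad(f, argnums=1)(a, z_star)
    df_da = jax.grad(f, argnums=0)(a, z_star)
    # Scalar case of (I - df/dz)^{-1} (df/da) v
    return (df_da / (1.0 - df_dz)) * v

a, v = 4.0, 1.0
print(implicit_jvp(a, v), 1.0 / (2.0 * jnp.sqrt(a)))   # both ~0.25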
from jaxopt import Bisection
@jax.jit
def F(x, factor):
return factor * x ** 3 - x - 2
def root(factor):
bisec = Bisection(optimality_fun=F, lower=1, upper=2)
return bisec.run(factor=factor).params
# Derivative of root with respect to factor at 2.0.
print(grad(root)(2.0))
-0.22139914
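As a sanity check, the implicit function theorem applied to \(F(x, c) = c x^3 - x - 2 = 0\) gives \[ \frac{dx^*}{dc} = -\frac{\partial F / \partial c}{\partial F / \partial x} = -\frac{(x^*)^3}{3 c (x^*)^2 - 1} \approx -0.2214 \quad \text{at } c = 2,\ x^* \approx 1.165, \] matching the output above.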