Automatic Differentiation

Paul Schrimpf

2024-10-02

Introduction

Derivatives

  • Needed for efficient equation solving and optimization
  • Can calculate automatically

Finite Differences

f(x) = sin(x)/(1.0+exp(x))

function dxfin(f,x)
  # forward difference; sqrt(eps) is a standard step size choice,
  # scaled by |x| when |x| > 1 to keep the relative step sensible
  h = sqrt(eps(x))
  if abs(x) > 1
    h = h*abs(x)
  end
  (f(x+h) - f(x))/h
end

dxfin(f, 2.0)
-0.1450763155960872
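
A central difference version (my addition, not from the original slides) is usually more accurate for about the same cost; a cube-root-of-eps step size is a common choice in that case.

function dxcentral(f,x)
  # central difference; cbrt(eps) roughly balances truncation and rounding error
  h = cbrt(eps(x))
  if abs(x) > 1
    h = h*abs(x)
  end
  (f(x+h) - f(x-h))/(2h)
end

dxcentral(f, 2.0) # compare with dxfin(f, 2.0) above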

Forward Automatic Differentiation

module Forward

struct Dual{T}
  v::T   # value
  dv::T  # derivative with respect to the input
end

# seed the derivative with one when wrapping the input variable
Dual(x::T) where {T} = Dual(x, one(x))

import Base: +, sin, exp, *, /

# adding a constant shifts the value and leaves the derivative unchanged
function (+)(a::T, x::Dual{T}) where {T}
  Dual(a+x.v, x.dv)
end

# product rule
function (*)(y::Dual, x::Dual)
  Dual(y.v*x.v, x.v*y.dv + x.dv*y.v)
end

# quotient rule
function (/)(x::Dual, y::Dual)
  Dual(x.v/y.v, x.dv/y.v - x.v*y.dv/y.v^2)
end

exp(x::Dual) = Dual(exp(x.v), exp(x.v)*x.dv)
sin(x::Dual) = Dual(sin(x.v), cos(x.v)*x.dv)


function fdx(f,x)
  out=f(Dual(x))
  (out.v, out.dv)
end

end

Forward.fdx(f,2.0)
(0.10839091026481387, -0.14507631594729084)

Reverse Automatic Differentiation

  • compute \(f(x)\) in usual forward direction, keep track of each operation and intermediate value
  • compute derivative “backwards”
    • \(f(x) = g(h(x))\)
    • \(f'(x) = g'(h(x)) h'(x)\)
  • scales better for high dimensional \(x\)
  • implementation is more complicated (a minimal sketch follows this list)
    • Simple-ish example https://simeonschaub.github.io/ReverseModePluto/notebook.html
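
As a rough illustration of the idea (a minimal sketch of my own, not the linked notebook's implementation; all names are hypothetical), record each operation's value and a "pullback" during the forward pass, then apply the pullbacks in reverse order:

module TinyReverse

# the tape stores intermediate values and pullbacks (local chain rule steps)
struct Tape
  values::Vector{Float64}
  pullbacks::Vector{Function}
end

# forward pass: compute op(x) and remember how to send a derivative backwards
function record!(tape::Tape, op, dop, x)
  y = op(x)
  push!(tape.values, y)
  push!(tape.pullbacks, ȳ -> ȳ*dop(x))
  return y
end

# f(x) = sin(exp(x)); the reverse sweep over the tape gives f'(x)
function fdx(x)
  tape = Tape(Float64[], Function[])
  h = record!(tape, exp, exp, x)
  g = record!(tape, sin, cos, h)
  x̄ = 1.0
  for pb in reverse(tape.pullbacks)
    x̄ = pb(x̄)
  end
  (g, x̄)
end

end

TinyReverse.fdx(0.5) # (sin(exp(0.5)), cos(exp(0.5))*exp(0.5))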

Julia AD Packages

ForwardDiff

ForwardDiff Example

using Distributions
function simulate_logit(observations, β)
  x = randn(observations, length(β))
  y = (x*β + rand(Logistic(), observations)) .>= 0.0
  return((y=y,x=x))
end

function logit_likelihood(β,y,x)
  p = map(xb -> cdf(Logistic(),xb), x*β)
  sum(log.(ifelse.(y, p, 1.0 .- p)))
end

n = 500
k = 3
β0 = ones(k)
(y,x) = simulate_logit(n,β0)

import ForwardDiff
∇L = ForwardDiff.gradient(b->logit_likelihood(b,y,x),β0)
3-element Vector{Float64}:
   3.76314337097665
 -11.968131147395866
   6.923502003034791
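
ForwardDiff also provides Jacobians and Hessians. For example (my addition, not output from the original slides), the Hessian of the log likelihood is useful for computing standard errors:

H = ForwardDiff.hessian(b->logit_likelihood(b,y,x), β0)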

ForwardDiff Notes

  • For \(f: \mathbb{R}^n \to \mathbb{R}^m\), the computation scales with \(n\)
    • best for moderate \(n\)
  • Code must be generic
    • be careful when allocating arrays
function wontwork(x)
  y = zeros(size(x)) # Vector{Float64}; cannot hold ForwardDiff's Dual numbers
  for i in eachindex(x)
    y[i] += x[i]*i
  end
  return(sum(y))
end

function willwork(x)
  y = zero(x) # element type matches x, so Dual numbers are fine
  for i in eachindex(x)
    y[i] += x[i]*i
  end
  return(sum(y))
end

betterstyle(x) = sum(v*i for (i,v) in enumerate(x))
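
A quick check of the point above (my addition, not output from the original slides):

ForwardDiff.gradient(willwork, ones(3))     # [1.0, 2.0, 3.0]
ForwardDiff.gradient(betterstyle, ones(3))  # [1.0, 2.0, 3.0]
# ForwardDiff.gradient(wontwork, ones(3))   # errors: cannot store a Dual in a Vector{Float64}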

Zygote

  • Zygote.jl
  • Does not allow mutating arrays (see the check after the example below)
  • Quite mature, but possibly some bugs remain
  • Apparently hard to develop, unclear future

Zygote Example

import Zygote
using LinearAlgebra
@time ∇Lz =  Zygote.gradient(b->logit_likelihood(b,y,x),β0)[1]
norm(∇L - ∇Lz)
  3.883209 seconds (6.06 M allocations: 403.763 MiB, 5.50% gc time, 99.99% compilation time)
1.1648350771590544e-14
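
To illustrate the mutation limitation noted above (a hypothetical check, not output from the original slides): Zygote handles the non-mutating betterstyle from earlier, but errors on willwork, which writes into y in place.

Zygote.gradient(betterstyle, ones(3))[1]  # [1.0, 2.0, 3.0]
# Zygote.gradient(willwork, ones(3))      # ERROR: Mutating arrays is not supported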

Enzyme

“Enzyme performs automatic differentiation (AD) of statically analyzable LLVM. It is highly-efficient and its ability to perform AD on optimized code allows Enzyme to meet or exceed the performance of state-of-the-art AD tools.”

import Enzyme
import Enzyme: Active, Duplicated, Const

db = zero(β0)  # the gradient with respect to β will accumulate here
# Active: differentiate the (scalar) return value
# Duplicated(β0,db): differentiate with respect to β0, storing the result in db
# Const(y), Const(x): treat y and x as constants
@time Enzyme.autodiff(Enzyme.ReverseWithPrimal, logit_likelihood, Active, Duplicated(β0,db), Const(y), Const(x))
db
┌ Warning: Using fallback BLAS replacements for (["dsymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/y4cj1/src/utils.jl:59
 25.912938 seconds (25.55 M allocations: 1.714 GiB, 3.54% gc time, 100.00% compilation time)
3-element Vector{Float64}:
   3.76314337097664
 -11.96813114739587
   6.923502003034802

Enzyme Notes

  • Documentation is not suited to beginners
  • Does not work on all Julia code, but cases where it fails are not well documented. Calling Enzyme.API.runtimeActivity!(true) works around some errors.
  • Cryptic error messages. Enzyme operates on LLVM IR, and error messages often reference the point in the LLVM IR where the error occurred. Figuring out what Julia code the LLVM IR corresponds to is not easy.
    • These may be better now than last year when I first wrote this slide
Enzyme.API.runtimeActivity!(false)
f1(a,b) = sum(a.*b)
dima = 30000
a = ones(dima)
b = rand(dima)
da = zeros(dima)
@time Enzyme.autodiff(Enzyme.ReverseWithPrimal, f1, Duplicated(a,da),Const(b))
da

f3(a,b) = sum(a[i]*b[i] for i in eachindex(a))
da = zeros(dima)
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f3, Duplicated(a,da),Const(b))
da

if (false) # will trigger enzyme error without runtimeactivity
  f2(a,b) = sum(a*b for (a,b) in zip(a,b))
  da = zeros(dima)
  @time Enzyme.autodiff(Enzyme.ReverseWithPrimal, f2, Duplicated(a,da), Const(b))
  da
end

Enzyme.API.runtimeActivity!(true)
f2(a,b) = sum(a*b for (a,b) in zip(a,b))
da = zeros(dima)
@time Enzyme.autodiff(Enzyme.ReverseWithPrimal, f2, Duplicated(a,da), Const(b))
da
  1.309099 seconds (2.36 M allocations: 167.310 MiB, 4.65% gc time, 99.95% compilation time)
  0.885099 seconds (2.06 M allocations: 148.183 MiB, 7.22% gc time, 99.94% compilation time)
30000-element Vector{Float64}:
 0.6604300486634319
 0.6270180550987297
 0.6299565612785084
 0.04494418448097137
 0.8856048079314317
 0.49066159256372843
 0.8628059721863847
 0.6945672822637726
 0.8878644919655609
 0.5158311069160438
 ⋮
 0.9388866128698892
 0.49862855571403253
 0.6131248043328932
 0.9804335362774208
 0.07423427954088635
 0.8459359730567572
 0.3172947744777822
 0.572499813003595
 0.01770995753210758

FiniteDiff

  • FiniteDiff computes finite difference gradients. Always test that whatever automatic or manual derivatives you compute are close to the finite difference versions.
  • Use a package for finite differences to handle rounding error well; a quick check is sketched below.
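
A minimal sketch (assuming FiniteDiff.jl is installed; not from the original slides), checking the ForwardDiff gradient computed earlier:

import FiniteDiff
∇Lfd = FiniteDiff.finite_difference_gradient(b->logit_likelihood(b,y,x), β0)
norm(∇L - ∇Lfd)  # small, but not exactly zero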

ChainRules

  • ChainRules
  • Used by many AD packages to define the derivatives of various functions.
  • Useful if you want to define a custom derivative rule for a function, as sketched below.
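
A minimal sketch of a custom reverse rule (assuming ChainRulesCore.jl is available; mysquare and its rule are hypothetical names used for illustration):

using ChainRulesCore

mysquare(x) = x^2

# reverse rule: return the primal value and a pullback applying the chain rule
function ChainRulesCore.rrule(::typeof(mysquare), x)
  y = mysquare(x)
  mysquare_pullback(ȳ) = (NoTangent(), 2x*ȳ)
  return y, mysquare_pullback
end

# ChainRules-aware packages (e.g. Zygote) will now use this rule:
# Zygote.gradient(mysquare, 3.0)  # (6.0,)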

DifferentiationInterface

import DifferentiationInterface as DI
DI.gradient(b->logit_likelihood(b,y,x), DI.AutoEnzyme(),β0)
┌ Warning: Using fallback BLAS replacements for (["dsymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/y4cj1/src/utils.jl:59
3-element Vector{Float64}:
   3.76314337097664
 -11.96813114739587
   6.923502003034802
  • improve performance by reusing intermediate variables
backend = DI.AutoEnzyme()
dcache = DI.prepare_gradient(b->logit_likelihood(b,y,x), backend, β0)
grad = zero(β0)
DI.gradient!(b->logit_likelihood(b,y,x), grad, backend, β0, dcache)
3-element Vector{Float64}:
   3.76314337097664
 -11.96813114739587
   6.923502003034802
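
One attraction of DifferentiationInterface is that switching AD backends only requires changing the backend object. For example (a hypothetical check, assuming ForwardDiff is installed):

DI.gradient(b->logit_likelihood(b,y,x), DI.AutoForwardDiff(), β0)  # same gradient via ForwardDiff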

Other Packages

ReverseDiff.jl

  • ReverseDiff.jl is a tape-based reverse mode package

  • Long lived and well tested

  • Has limitations: importantly, code must be generic, and mutation of arrays is not allowed. A minimal usage sketch follows.
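
A minimal usage sketch (assuming ReverseDiff.jl is installed and handles this likelihood; not from the original slides):

import ReverseDiff
∇Lr = ReverseDiff.gradient(b->logit_likelihood(b,y,x), β0)
norm(∇L - ∇Lr)  # should match the ForwardDiff gradient up to rounding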

Yota.jl

Tracker

Tracker is a tape based reverse mode package. It was the default autodiff package in Flux before being replaced by Zygote. No longer under active development.

Diffractor

Diffractor is an automatic differentiation package in development. It was once hoped to be the future of AD in Julia, but development has been delayed. It is planned to have both forward and reverse mode, but only forward mode is available so far.
