ECON622: Computational Economics with Data Science Applications

Key Terminology and Concepts in Machine Learning

Jesse Perla

University of British Columbia

Overview

Motivation

  • In this lecture we will introduce some key terminology and concepts in machine learning
  • We will map these concepts to terminology in economics and statistics
  • Finally, we will discuss Python Frameworks

Textbooks

“Depth” and Representations

  • Could approximate a function \(f(X)\) with a “shallow” approximation, e.g. polynomials of \(X\). Alternatively, nest functions \(h(\cdot)\) and \(\phi(\cdot)\) \[ f(X) \approx h(\phi(X)) \]
    • First, the \(\phi(X)\) will transform the state into something more amenable for the downstream task (e.g. prediction, classification, etc.)
    • Then the \(h(\cdot)\) maps that transformed state into the output.
  • A good \(\phi(\cdot)\) makes \(h(\cdot)\) efficient to calculate, and is often reusable for other tasks (e.g. \(f_2(X) \approx h_2(\phi(X))\))
  • For simple \(X\) we can design \(\phi(\cdot)\) by hand (e.g., take means, logs, first-differences). But for rich data, can we learn it from the data itself? (a minimal sketch follows this list)
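
To make the composition concrete, here is a minimal sketch in which the feature map, the linear head, and the data are all hypothetical choices for illustration: a hand-designed \(\phi(X)\) turns raw data into a few summary features, and a simple \(h(\cdot)\) maps them to the output.

import jax.numpy as jnp

# Hypothetical "raw" data: 10 observations with 5 columns each
X = jnp.exp(jnp.linspace(0.0, 1.0, 50)).reshape(10, 5)

def phi(X):
  # Hand-designed feature map: means, logs, and first-differences
  return jnp.column_stack([
    X.mean(axis=1),                    # cross-column mean
    jnp.log(X[:, 0]),                  # log of the first column
    jnp.diff(X, axis=1).mean(axis=1),  # average first-difference
  ])

def h(features, weights):
  # Simple "head": a linear map from the transformed state to the output
  return features @ weights

weights = jnp.array([1.0, -0.5, 2.0])  # placeholder coefficients
f_approx = h(phi(X), weights)          # f(X) ≈ h(phi(X))
print(f_approx.shape)                  # (10,)

The same \(\phi\) could be reused for a second task \(f_2(X) \approx h_2(\phi(X))\) by swapping in a different head; representation learning replaces the hand-designed \(\phi\) with one learned from the data.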

Relevant Categories of Machine Learning

  • Supervised Learning (e.g., Regression and Classification)
  • Unsupervised Learning (e.g., Clustering, Auto-encoders)
  • Semi-Supervised Learning (e.g., only some observations are labeled)
  • Reinforcement Learning (e.g., policy/control)
  • Generative Models/Bayesian Methods (e.g., diffusions, probabilistic programming)
  • Instance-based learning (e.g., kernel methods) and deep learning are somewhat orthogonal to these categories

Key Terminology: Features, Labels, and Latents

  • Features are what economists call explanatory or independent variables. They are the key source of variation used to make predictions and conduct counterfactuals
  • Labels correspond to economists’ observables or dependent variables
  • Latent Variables are unobserved variables, typically sources of heterogeneity, which may drive both the dependent and independent variables
  • Feature Engineering is the process of creating or selecting features from the data that are more useful for the task at hand

Key Concepts

  • 20th vs. 21st Century ML
  • Stochastic Gradients and Auto-Differentiation
  • Implicit and Explicit Regularization
  • Inductive/Implicit Bias
  • Generalization
  • Overfitting and the Bias-Variance Tradeoff
  • Test vs. Train vs. Validation Set
  • Hyperparameter Optimization
  • Representation Learning
  • Transfer Learning

Python

Why Python?

  • For “modern” ML: all the well-supported frameworks are in Python
  • In particular, auto-differentiation is central to many ML algorithms
  • Why should you avoid Julia/Matlab/R in these cases?
    • Poor AD, especially for reverse-mode
    • Network effects: very few higher-level packages for the ML pipeline
    • But Julia dominates for some scientific ML topics (e.g. ODEs), and R is outstanding for classic ML
  • Should you use Python for more things?
    • Maybe, but it is limited and can be slow unless you jump through hoops
    • Personally, if I am writing algorithms with no need for AD or particular packages, Julia is a much better and less frustrating language

There is No Such Thing as “Python”!

  • Many incompatible wrappers around C++ for numerical methods
  • Numpy/Scipy is the baseline (a common API)
  • Pytorch
  • JAX
  • Ones to avoid
    • Tensorflow, common in industry but old
    • Numba (for me, reasonable people disagree)

Pytorch

  • In recent years, the most flexible and popular ML framework for researchers
  • Key features:
    • Most of the code is for auto-differentiation and GPU support (see the autograd sketch after this list)
    • JIT compilation for GPUs and fast kernels for deep learning
    • Neural network libraries and utilities
    • A good subset of numpy
    • Utilities for ML pipelines, optimization, etc.
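
As a minimal autograd sketch (the toy function and values are illustrative assumptions, not course code): PyTorch tracks operations on tensors created with requires_grad=True and fills in .grad after a backward pass.

import torch

# Gradient of a scalar loss with respect to a parameter vector
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (x ** 2).sum()   # simple scalar-valued function of x
loss.backward()         # reverse-mode AD populates x.grad
print(x.grad)           # tensor([2., 4., 6.]), i.e. d(sum x^2)/dx = 2x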

Pytorch Key Downsides

  • Not really for general purpose programming
    • Intended to make auto-differentiation of neural networks easy and to support gradient-based updates in solvers
    • May be very slow for simple things, or for code that doesn’t involve higher-order AD
  • Won’t always have the packages you need for general code, and compatibility can be ugly

JAX

  • Compiler that enables layered program transformations
    1. jit: compilation to XLA, including accelerators (e.g. GPUs)
    2. grad: auto-differentiation
    3. vmap: vectorization
    4. Flexibility to add more transformations
  • JAX PyTrees provide a nested tree structure for compiler passes
  • Closer to being a full JIT for general code than pytorch
  • For ML, it is not as full-featured as pytorch; you need to shop around for other libraries

JAX Key Downsides

  • Tough to trust Google, especially since it is a research project
    • Too ingrained in DeepMind research to disappear, but might have intermittent support
  • Support across operating systems is uneven, but they are making progress
  • Only a subset of Python: can’t really use standard loops, etc.; requires functional-style programming
    • Much more restrictive than it seems, and far more restrictive than pytorch

Python Ecosystem

Environments

  • Hate it, but nevertheless install conda
  • Always use a virtual environment to hate conda slightly less
  • Keep dependencies in requirements.txt for pip, or a conda environment.yml if necessary (a minimal sketch follows below)
  • To create, activate, and install packages
conda create -n econ622 python=3.11
conda activate econ622
pip install -r requirements.txt
  • In VS Code, run > Python: Select Interpreter from the Command Palette, choose econ622, and it will activate automatically
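
As a minimal sketch of what a requirements.txt might contain (the package list and lack of version pins are illustrative assumptions, not the course's actual file):

# requirements.txt (illustrative)
numpy
scipy
matplotlib
jax
optax

pip install -r requirements.txt then installs everything into the active environment, so the same file reproduces the setup on another machine.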

Development Environment

  • pip/anaconda with virtual environments
  • Use VSCode for debugging, testing, etc.
  • Github Copilot. Install Github Copilot Chat
  • Format with black. See here for setup

Baseline, Safe Packages to Use

Numpy/Scipy Basics

General Tools for ML Pipelines

  • Logging/visualization: Weights and Biases
    • Sign up for an account! Built in hyperparameter optimization tools
  • A CLI is useful for many pipelines and HPO. See here
  • For more end-to-end frameworks for deep-learning
    • Pytorch Lightning is extremely easy and flexible, eliminating a lot of boilerplate for CLI, optimizers, GPUs, etc.
    • Keras is a higher-level framework for deep learning. Traditionally tensorflow, but now many. Also FastAI
  • HuggingFace is a great resource for NLP and transformers
  • Optuna is a great hyperparameter optimization framework (a minimal sketch follows this list)
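
A minimal Optuna sketch, assuming the objective is just a toy quadratic (a real use case would wrap a training loop and return a validation metric):

import optuna

def objective(trial):
  # Hypothetical hyperparameter: a single float searched over [-10, 10]
  x = trial.suggest_float("x", -10.0, 10.0)
  return (x - 2.0) ** 2  # pretend this is a validation loss

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)  # should be close to {"x": 2.0}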

JAX Ecosystem

Examples of Core Transformations

From JAX quickstart

Builtin composable transformations: jit, grad and vmap

import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap, random

Compiling with jit

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
key = random.PRNGKey(0)  
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
2.38 ms ± 65.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
602 μs ± 68.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Convenience Decorators for jit

  • Convenience python decorator @jit
@jit
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
%timeit selu(x).block_until_ready()
690 μs ± 23.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Differentiation with grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25       0.19661197 0.10499357]
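
As a check (an added illustration, not from the quickstart): the gradient of the summed logistic is the elementwise \(\sigma(x)(1-\sigma(x))\), which matches the output above.

def logistic(x):
  return 1.0 / (1.0 + jnp.exp(-x))

# Analytic derivative of sum(logistic(x)) with respect to each element
analytic = logistic(x_small) * (1.0 - logistic(x_small))
print(jnp.allclose(derivative_fn(x_small), analytic))  # True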

Manual “Batching”/Vectorization

Common to run the same function along one dimension of an array

mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def f(v):
  return jnp.dot(mat, v)
def naively_batched_f(v_batched):
  return jnp.stack([f(v) for v in v_batched])
%timeit naively_batched_f(batched_x).block_until_ready()  
831 μs ± 24.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Using vmap

vmap applies a function across a dimension of its inputs

@jit
def vmap_batched_f(v_batched):
  return vmap(f)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_f(batched_x).block_until_ready()
Auto-vectorized with vmap
47.2 μs ± 729 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

More vmap

Can fix dimensions with in_axes

def f(a, x, y):
  return a * x + y
a = 2.0
x = jnp.arange(5.)
y = jnp.arange(5.)
vmap(f, in_axes=(None, 0, 0))(a, x, y)
Array([ 0.,  3.,  6.,  9., 12.], dtype=float32)

Save vmap functions

The vmapped function can be saved and reused

@jax.jit
def f(a, x, y):
  return a * x + y
a = 2.0
x = jnp.arange(5.)
y = jnp.arange(5.)
f_batched = vmap(f, in_axes=(None, 0, 0))
f_batched(a, x, y)
Array([ 0.,  3.,  6.,  9., 12.], dtype=float32)

Key JAX Neural Network Libraries/Frameworks

  • Neural Network Libraries
    • Flax NNX
      • NNX is the new Flax API, Linen the older one
      • Has momentum, supported by google (for now)
    • Equinox
      • General, not just neural networks. Similar to NNX
    • Keras supports JAX (as well as PyTorch, TF, etc.)

Other ML-oriented Packages

  • Tough to keep up, see Awesome JAX
  • Optax for ML-style optimization (a minimal sketch follows this list)
  • Checkpointing and serialization: Orbax
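
A minimal Optax sketch (the quadratic loss and initial parameters are placeholder assumptions) showing the usual init/update/apply_updates loop:

import jax
import jax.numpy as jnp
import optax

def loss(params):
  return jnp.sum((params - 3.0) ** 2)  # toy objective with minimum at 3

params = jnp.zeros(2)
optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(params)

for _ in range(200):
  grads = jax.grad(loss)(params)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

print(params)  # approximately [3. 3.]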

More Scientific Computing in JAX

  • jax.scipy which is a subset of scipy
  • Nonlinear Systems/Least Squares: Optimistix
  • Linear Systems of Equations: Lineax
  • Matrix-free operators for iterative solvers: COLA
  • Differential Equations: diffrax
  • More general optimization and solvers: JAXopt
  • Interpolation: interpax

JAX Challenges

  • Basically only pure functional programming
    • No “mutation” of arrays (see the sketch after this list)
    • Loops/conditionals are tough
    • Rules for what is jitable are tricky
  • See JAX - The Sharp Bits
  • May not be faster on CPUs or for “normal” things
  • Debugging
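
A minimal sketch of the functional style: JAX arrays are immutable, so in-place updates are written with .at[...] and simple data-dependent branches with jnp.where.

import jax.numpy as jnp

x = jnp.zeros(5)
# x[0] = 10.0 would raise an error; .at returns a new array instead
x = x.at[0].set(10.0)
x = x.at[1:].add(1.0)
print(x)  # [10.  1.  1.  1.  1.]

# Branch on array values with jnp.where instead of a Python if
y = jnp.where(x > 5.0, x, -x)
print(y)  # [10. -1. -1. -1. -1.]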

PyTrees

f = lambda x, y: jnp.vdot(x, y)
X = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
y = jnp.array([3.0, 4.0])
print(f(X[0], y))
print(f(X[1], y))

mv = vmap(f, in_axes = (
  0, # broadcast over 1st index of first argument
  None # don't broadcast over anything of second argument
  ), out_axes=0)
print(mv(X, y))
11.0
25.0
[11. 25.]

PyTree Example 1

The in_axes can match more complicated structures

dct = {'a': 0., 'b': jnp.arange(5.)}
def foo(dct, x):
  return dct['a'] + dct['b'] + x
# axes must match shape of the PyTree
x = 1.
out = vmap(foo, in_axes=(
  {'a': None, 'b': 0}, #broadcast over the 'b'
  None # no broadcasting over the "x"
  ))(dct, x)
# example now: {'a': 0, 'b': 0} etc.
print(out)
[1. 2. 3. 4. 5.]

PyTree Example 2

dct = {'a': jnp.array([3.0, 5.0]), 'b': jnp.array([2.0, 4.0])}
def foo2(dct, x):
  return dct['a'] + dct['b'] + x
# axes must match shape of the PyTree
x = 1.
out = vmap(foo2, in_axes=(
  {'a': 0, 'b': 0}, #broadcast over the 'a' and 'b'
  None # no broadcasting over the "x"
  ))(dct, x)
# example now: {'a': 3.0, 'b': 2.0} etc.
print(out)
[ 6. 10.]

PyTree Example 3

dct = {'a': jnp.array([3.0, 5.0]), 'b': jnp.arange(5.)}
def foo3(dct, x):
  return dct['a'][0] * dct['a'][1] + dct['b'] + x
# axes must match shape of the PyTree
out = vmap(foo3, in_axes=(
  {'a': None, 'b': 0}, #broadcast over the 'b'
  None # no broadcasting over the "x"
  ))(dct, x)
# example now: {'a': [3.0, 5.0], 'b': 0} etc.
print(out)
[16. 17. 18. 19. 20.]