Key Terminology and Concepts in Machine Learning
- `jit`: just-in-time compilation to XLA, which targets accelerators (e.g. GPUs)
- `grad`: auto-differentiation
- `vmap`: vectorization
- `requirements.txt` for pip, or a conda `environment.yml` if necessary
- Run `> Python: Select Interpreter`, choose `econ622`, and it will activate automatically
- `black`: see here for setup

From the JAX quickstart: the built-in composable transformations are `jit`, `grad`, and `vmap`.
```python
import jax.numpy as jnp
from jax import jit, random

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

key = random.PRNGKey(0)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
```

```
2.38 ms ± 65.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
602 μs ± 68.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
`jit` can also be applied as a decorator, i.e. `@jit` above a function definition. The `grad` transformation similarly takes a function and returns a new function that computes its gradient.
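As a minimal sketch of `grad` (the `loss` function here is illustrative, not from the notes): applied to a scalar-valued function, `grad` returns a function that computes the gradient with respect to the first argument by default.

```python
import jax.numpy as jnp
from jax import grad

def loss(w):
    # a scalar-valued function of a vector input
    return jnp.sum(jnp.tanh(w) ** 2)

g = grad(loss)  # gradient with respect to w
w = jnp.array([0.0, 1.0, 2.0])
print(g(w))     # same shape as w, i.e. (3,)
```

The analytic gradient is `2 * tanh(w) * (1 - tanh(w)**2)`, which the output matches elementwise.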
It is common to run the same function along one dimension of an array:

```python
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def f(v):
    return jnp.dot(mat, v)

def naively_batched_f(v_batched):
    return jnp.stack([f(v) for v in v_batched])

%timeit naively_batched_f(batched_x).block_until_ready()
```

```
831 μs ± 24.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
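The Python loop above can be replaced by `vmap`, which maps `f` over the leading axis of `batched_x` in a single traced call. A minimal sketch (timings will vary by machine, so only shapes are shown):

```python
import jax.numpy as jnp
from jax import vmap, random

key = random.PRNGKey(0)
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def f(v):
    return jnp.dot(mat, v)

# vmap maps f over axis 0 of its argument; the result stacks the
# per-row outputs, equivalent to the jnp.stack loop above
batched_f = vmap(f)
out = batched_f(batched_x)
print(out.shape)  # (10, 150)
```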
`vmap` applies a function across a dimension of an array:

- Dimensions to map over can be chosen with `in_axes`
- `vmap`-ing functions that are not `jit`-able is tricky

```python
f = lambda x, y: jnp.vdot(x, y)
X = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
y = jnp.array([3.0, 4.0])
print(f(X[0], y))
print(f(X[1], y))

mv = vmap(f, in_axes=(
    0,     # map over the first axis of the first argument
    None   # do not map over the second argument
), out_axes=0)
print(mv(X, y))
```

```
11.0
25.0
[11. 25.]
```
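The `in_axes` choices compose when `vmap`s are nested. As an illustrative sketch (not from the original notes), all pairwise dot products between the rows of two matrices can be computed by nesting two `vmap`s with complementary `in_axes`:

```python
import jax.numpy as jnp
from jax import vmap

f = lambda x, y: jnp.vdot(x, y)
X = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
Y = jnp.array([[1.0, 0.0],
               [0.0, 1.0]])

# inner vmap maps over rows of Y; outer vmap maps over rows of X,
# so pairwise(X, Y)[i, j] == vdot(X[i], Y[j]), i.e. X @ Y.T
pairwise = vmap(vmap(f, in_axes=(None, 0)), in_axes=(0, None))
print(pairwise(X, Y))
```

With `Y` the identity matrix, the result is just `X` again, which makes the indexing easy to check by hand.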
The `in_axes` specification can match more complicated PyTree structures:
```python
dct = {'a': 0., 'b': jnp.arange(5.)}

def foo(dct, x):
    return dct['a'] + dct['b'] + x

# in_axes must match the structure of the PyTree
x = 1.
out = vmap(foo, in_axes=(
    {'a': None, 'b': 0},  # map over 'b' only
    None                  # do not map over x
))(dct, x)
# each call sees e.g. {'a': 0., 'b': 0.}, {'a': 0., 'b': 1.}, etc.
print(out)
```

```
[1. 2. 3. 4. 5.]
```
```python
dct = {'a': jnp.array([3.0, 5.0]), 'b': jnp.array([2.0, 4.0])}

def foo2(dct, x):
    return dct['a'] + dct['b'] + x

# in_axes must match the structure of the PyTree
x = 1.
out = vmap(foo2, in_axes=(
    {'a': 0, 'b': 0},  # map over both 'a' and 'b'
    None               # do not map over x
))(dct, x)
# each call sees e.g. {'a': 3.0, 'b': 2.0}, then {'a': 5.0, 'b': 4.0}
print(out)
```

```
[ 6. 10.]
```
```python
dct = {'a': jnp.array([3.0, 5.0]), 'b': jnp.arange(5.)}

def foo3(dct, x):
    return dct['a'][0] * dct['a'][1] + dct['b'] + x

# in_axes must match the structure of the PyTree; x = 1. from above
out = vmap(foo3, in_axes=(
    {'a': None, 'b': 0},  # map over 'b'; pass 'a' whole
    None                  # do not map over x
))(dct, x)
# each call sees e.g. {'a': [3.0, 5.0], 'b': 0.}, {'a': [3.0, 5.0], 'b': 1.}, etc.
print(out)
```

```
[16. 17. 18. 19. 20.]
```
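Because the transformations are composable, they can be stacked. A minimal sketch (the `predict`/`loss` names are illustrative, not from the notes) that differentiates a per-example loss, maps it over a batch, and compiles the whole thing:

```python
import jax.numpy as jnp
from jax import jit, grad, vmap

def predict(w, x):
    return jnp.tanh(jnp.dot(w, x))

def loss(w, x, y):
    return (predict(w, x) - y) ** 2

# grad differentiates wrt w (argument 0); vmap maps over the batch
# of (x, y) pairs while holding w fixed; jit compiles the result
batch_grad = jit(vmap(grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([0.1, 0.2])
xs = jnp.ones((5, 2))
ys = jnp.zeros(5)
print(batch_grad(w, xs, ys).shape)  # (5, 2): one gradient per example
```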