Generalization, Deep Learning, and Representations
\[ R(f, p^*) \equiv \mathbb{E}_{p^*(x,y)}\left[\ell(f, x, y)\right] \]
\[ f^{**} = \arg\min_{f \in \mathcal{F}} R(f, p^*) \]
\[ \min_{f \in \mathcal{F}} R(f, p^*) \equiv \min_{f \in \mathcal{F}} \mathbb{E}_{p^*(x)}\left[\ell(f, x)\right] = 0 \]
\[ R(f, \mathcal{D}) \equiv \frac{1}{N}\sum_{n=1}^N \ell(f, x_n, y_n) = R(f, p_{\mathcal{D}}) \]
\[ f^*_{\mathcal{D}} = \arg\min_{f \in \mathcal{H}} R(f, \mathcal{D}) = \arg\min_{f \in \mathcal{H}}\frac{1}{N}\sum_{n=1}^N \ell(f, x_n, y_n) \]
\[ \arg\min_{\theta \in \Theta} \frac{1}{N}\sum_{n=1}^N \ell(f_{\theta}, x_n, y_n) \]
\[ f^* \equiv \arg\min_{f \in \mathcal{H}} R(f, p^*) \]
\[ \varepsilon_{app}(\mathcal{H}) \equiv R(f^*, p^*) - R(f^{**}, p^*) \]
\[ \varepsilon_{est}(\mathcal{H}) \equiv \mathbb{E}_{\mathcal{D} \sim p^*}\left[R(f^*_{\mathcal{D}}, \mathcal{D}) - R(f^*, p^*)\right] \]
\[ \mathbb{E}_{\mathcal{D} \sim p^*}\left[\min_{f \in \mathcal{H}} R(f, \mathcal{D}) - \min_{f \in \mathcal{F}} R(f, p^*)\right] = \varepsilon_{app}(\mathcal{H}) + \varepsilon_{est}(\mathcal{H}) \]
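The definitions above can be made concrete with a small numerical sketch. It assumes a hypothetical data-generating process for \(p^*\) (uniform \(x\), \(y = \sin(2x)\) plus noise), squared-error loss, and takes \(\mathcal{H}\) to be affine functions, so that ERM has a closed form via least squares. None of these choices come from the lecture; they only illustrate how \(R(f, \mathcal{D})\) and an approximation of \(R(f, p^*)\) are computed.

```python
import jax
import jax.numpy as jnp

def empirical_risk(theta, X, Y):
    """R(f_theta, D) with squared loss, where f_theta(x) = theta[0] + theta[1] * x."""
    preds = theta[0] + theta[1] * X
    return jnp.mean((preds - Y) ** 2)

def sample(key, n):
    """Hypothetical p*: x ~ U[-1, 1], y = sin(2x) + small Gaussian noise."""
    kx, ke = jax.random.split(key)
    X = jax.random.uniform(kx, (n,), minval=-1.0, maxval=1.0)
    Y = jnp.sin(2.0 * X) + 0.05 * jax.random.normal(ke, (n,))
    return X, Y

k_train, k_pop = jax.random.split(jax.random.PRNGKey(0))
X, Y = sample(k_train, 50)             # the dataset D
X_pop, Y_pop = sample(k_pop, 100_000)  # large sample standing in for p*

# ERM over H (affine functions) is ordinary least squares
A = jnp.stack([jnp.ones_like(X), X], axis=1)
theta_D, *_ = jnp.linalg.lstsq(A, Y)

train_risk = empirical_risk(theta_D, X, Y)        # R(f*_D, D)
pop_risk = empirical_risk(theta_D, X_pop, Y_pop)  # approximates R(f*_D, p*)
print(train_risk, pop_risk)
```

Since affine functions cannot represent \(\sin(2x)\) exactly, part of the remaining risk here is approximation error from the choice of \(\mathcal{H}\) rather than from the finite sample.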
To abstract from this manual process, we can instead take the raw data \(x\) and transform it into a representation \(z\in \mathcal{Z}\) with \(g : \mathcal{X} \to \mathcal{Z}\)
Then, instead of finding an \(f : \mathcal{X} \to \mathcal{Y}\), we can find an \(\tilde{f} : \mathcal{Z} \to \mathcal{Y}\) and use \(\ell(\tilde{f}\circ g, x, y)\) in our loss
\[ \tilde{f}^*_{\mathcal{D}} = \arg\min_{\tilde{f} \in \mathcal{H}}\frac{1}{N}\sum_{n=1}^N \ell(\tilde{f} \circ g, x_n, y_n) \]
Our approximation class \(\mathcal{H}\) then changes - for better or worse.
Use economic intuition and problem specific knowledge to design \(\mathcal{H}\)
For example, you can approximate functions \(f : \mathbb{R}^N \to \mathbb{R}\) that are symmetric in their arguments (i.e., permutation invariant) with
\[ f(X) = \rho\left(\frac{1}{N}\sum_{x\in X} \phi(x)\right) \]
See Probabilistic Symmetries and Invariant Neural Networks or Exploiting Symmetry in High Dimensional Dynamic Programming
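A minimal sketch of the permutation-invariant form \(f(X) = \rho\left(\frac{1}{N}\sum_{x\in X} \phi(x)\right)\) follows. The particular `phi` and `rho` here are hypothetical hand-written functions chosen for illustration; in practice both would typically be neural networks.

```python
import jax
import jax.numpy as jnp

def phi(x):
    # Hypothetical per-element feature map: scalar -> R^2
    return jnp.stack([x, x ** 2])

def rho(z):
    # Hypothetical outer map: R^2 -> scalar
    return jnp.tanh(z[0]) + z[1]

def f(X):
    # Average the per-element features, then apply rho
    return rho(jnp.mean(jax.vmap(phi)(X), axis=0))

X = jnp.array([0.3, -1.2, 2.0, 0.7])
X_perm = X[jnp.array([2, 0, 3, 1])]  # same elements, different order

print(f(X), f(X_perm))  # equal up to floating-point summation order
```

Because the mean over elements does not depend on their order, any choice of `phi` and `rho` yields a function that is symmetric in its arguments by construction.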
\[ \min_{\theta_e, \theta_d} \mathbb{E}_{p^*(x)} (h(g(x;\theta_e);\theta_d) - x)^2 + \text{regularizer} \]
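As a sketch of this objective, the following assumes a linear encoder \(g(x;\theta_e) = \theta_e x\) and linear decoder \(h(z;\theta_d) = \theta_d z\), replaces the expectation over \(p^*(x)\) with a sample average, and uses an L2 penalty as the regularizer. The dimensions, penalty weight `lam`, and step size are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

def g(x, theta_e):  # encoder: R^4 -> R^2
    return theta_e @ x

def h(z, theta_d):  # decoder: R^2 -> R^4
    return theta_d @ z

def objective(params, X, lam=1e-3):
    theta_e, theta_d = params
    recon = jax.vmap(lambda x: h(g(x, theta_e), theta_d))(X)
    mse = jnp.mean(jnp.sum((recon - X) ** 2, axis=1))  # sample analog of E[(h(g(x)) - x)^2]
    reg = lam * (jnp.sum(theta_e ** 2) + jnp.sum(theta_d ** 2))
    return mse + reg

kx, ke, kd = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(kx, (256, 4))
params = (0.1 * jax.random.normal(ke, (2, 4)),
          0.1 * jax.random.normal(kd, (4, 2)))

loss_before = objective(params, X)
grad_fn = jax.jit(jax.grad(objective))
for _ in range(200):  # plain gradient descent on (theta_e, theta_d)
    grads = grad_fn(params, X)
    params = jax.tree.map(lambda p, gr: p - 0.01 * gr, params, grads)
loss_after = objective(params, X)
print(loss_before, loss_after)
```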
MyLinear case for a linear function without an affine term (i.e., no “bias”)
nnx.Param
nnx.Module
MyLinear: it finds the nnx.Param.
nnx.grad does not perturb out_size, etc.
State({
'kernel': VariableState(
type=Param,
value=Array([[-1.5557193, -0.8495713, -1.1160917]], dtype=float32)
)
})
nnx.Module
nnx.Linear instead, construct a simple NN
class MyMLP(nnx.Module):
    def __init__(self, din, dout, width: int, *, rngs: nnx.Rngs):
        self.width = width
        self.linear1 = nnx.Linear(din, width, use_bias=False, rngs=rngs)
        self.linear2 = nnx.Linear(width, dout, use_bias=True, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        return x

m = MyMLP(N, 1, 2, rngs=rngs)
f(m). Recall notation in Differentiation lecture
jax.grad does, for scalar functions
nnx.grad does this, recursively going through each nnx.Module and its nnx.Param values
State({
'linear1': {
'bias': VariableState(
type=Param,
value=None
),
'kernel': VariableState(
type=Param,
value=Array([[-0.11148886, 0.66866976],
[-0.09731744, -0.486882 ],
[-0.9420541 , -0.13140532]], dtype=float32)
)
},
'linear2': {
'bias': VariableState(
type=Param,
value=Array([0.], dtype=float32)
),
'kernel': VariableState(
type=Param,
value=Array([[-0.1044913 ],
[-0.18319812]], dtype=float32)
)
}
})
graphdef Contains Fixed Values and Metadata
GraphDef(
nodedef=NodeDef(
type=MyMLP,
index=0,
attributes=('linear1', 'linear2', 'width'),
subgraphs={
'linear1': NodeDef(
type=Linear,
index=1,
attributes=('bias', 'bias_init', 'dot_general', 'dtype', 'in_features', 'kernel', 'kernel_init', 'out_features', 'param_dtype', 'precision', 'use_bias'),
subgraphs={
'dtype': NodeDef(
type=PytreeType,
index=-1,
attributes=(),
subgraphs={},
static_fields={},
leaves={},
metadata=PyTreeDef(None)
),
'precision': NodeDef(
type=PytreeType,
index=-1,
attributes=(),
subgraphs={},
static_fields={},
leaves={},
metadata=PyTreeDef(None)
)
},
static_fields={
'bias_init': <function zeros at 0x0000026AE86D0E00>,
'dot_general': <function dot_general at 0x0000026AE8187060>,
'in_features': 3,
'kernel_init': <function variance_scaling.<locals>.init at 0x0000026AF2993740>,
'out_features': 2,
'param_dtype': <class 'jax.numpy.float32'>,
'use_bias': False
},
leaves={
'bias': 2,
'kernel': 3
},
metadata=<class 'flax.nnx.nnx.nn.linear.Linear'>
),
'linear2': NodeDef(
type=Linear,
index=4,
attributes=('bias', 'bias_init', 'dot_general', 'dtype', 'in_features', 'kernel', 'kernel_init', 'out_features', 'param_dtype', 'precision', 'use_bias'),
subgraphs={
'dtype': NodeDef(
type=PytreeType,
index=-1,
attributes=(),
subgraphs={},
static_fields={},
leaves={},
metadata=PyTreeDef(None)
),
'precision': NodeDef(
type=PytreeType,
index=-1,
attributes=(),
subgraphs={},
static_fields={},
leaves={},
metadata=PyTreeDef(None)
)
},
static_fields={
'bias_init': <function zeros at 0x0000026AE86D0E00>,
'dot_general': <function dot_general at 0x0000026AE8187060>,
'in_features': 2,
'kernel_init': <function variance_scaling.<locals>.init at 0x0000026AF2993740>,
'out_features': 1,
'param_dtype': <class 'jax.numpy.float32'>,
'use_bias': True
},
leaves={
'bias': 5,
'kernel': 6
},
metadata=<class 'flax.nnx.nnx.nn.linear.Linear'>
)
},
static_fields={
'width': 2
},
leaves={},
metadata=<class '__main__.MyMLP'>
),
index_mapping=None
)
nnx.Module
State({
'linear1': {
'bias': VariableState(
type=Param,
value=None
),
'kernel': VariableState(
type=Param,
value=Array([[-0.784888 , -2.226523 ],
[-0.42862383, -1.2158941 ],
[-0.5630881 , -1.5973343 ]], dtype=float32)
)
},
'linear2': {
'bias': VariableState(
type=Param,
value=Array([1.], dtype=float32)
),
'kernel': VariableState(
type=Param,
value=Array([[0.97357047],
[1.0623739 ]], dtype=float32)
)
}
})
graphdef
m by applying the graphdef
state from before, and make a new type using the graphdef
eta = 0.01 # e.g., a gradient descent update
# jax.tree.map recursively goes through the model
# Updates the underlying nnx.Param given the delta_m grad
new_state = jax.tree.map(
    lambda p, g: p - eta*g,
    state, delta_m)  # new_state = state - eta * delta_m
m_new = nnx.merge(graphdef, new_state)
f(m_new)
Array(0.84317195, dtype=float32)
nnx.jit, nnx.vmap, nnx.grad will automatically split and merge for you (i.e., filtering) on nnx.Module types as arguments, then call underlying JAX functions
nnx.grad etc. would not work without modification since the NN combines differentiable and non-differentiable parts
state and graphdef and then merge them back together
f_gen function
state
graphdef, state = nnx.split(m)

@jax.jit  # note jax.jit instead of nnx.jit
def f_split(state):  # closure on graphdef
    m = nnx.merge(graphdef, state)
    return f_gen(m, x, b)

# Can use jax.grad, rather than nnx.grad
state_diff = jax.grad(f_split)(state)
print(state_diff)
new_state = jax.tree.map(
    lambda p, g: p - eta*g,
    state, state_diff)  # gradient step using the gradient computed above
m_new = nnx.merge(graphdef, new_state)
f(m_new)
State({
'linear1': {
'bias': VariableState(
type=Param,
value=None
),
'kernel': VariableState(
type=Param,
value=Array([[-0.784888 , -2.226523 ],
[-0.42862383, -1.2158941 ],
[-0.5630881 , -1.5973343 ]], dtype=float32)
)
},
'linear2': {
'bias': VariableState(
type=Param,
value=Array([1.], dtype=float32)
),
'kernel': VariableState(
type=Param,
value=Array([[0.97357047],
[1.0623739 ]], dtype=float32)
)
}
})
Array(2.8807144, dtype=float32)
nnx.Param in PyTorch is torch.nn.Parameter
import torch
from torch import nn

class MyLinearTorch(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.out_size = out_size
        self.in_size = in_size
        # nn.Parameter marks the kernel as trainable (the analog of nnx.Param)
        self.kernel = nn.Parameter(torch.randn(out_size, in_size))

    # forward plays the role of __call__ in the nnx.Module version
    def forward(self, x):
        return self.kernel @ x

def f_gen_torch(m, x, b):
    return torch.squeeze(m(x) + b)
Parameter containing:
tensor([[-0.6374, -0.7554, 0.5486]], requires_grad=True)
class MyMLPTorch(nn.Module):
    def __init__(self, din, dout, width):
        super(MyMLPTorch, self).__init__()
        self.width = width
        self.linear1 = nn.Linear(din, width, bias=False)
        self.linear2 = nn.Linear(width, dout, bias=True)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x

m = MyMLPTorch(N, 1, 2)
m.zero_grad()
output = f_torch(m)
# Start with d output = [1.0]
output.backward()
# Now `m` has the gradients
# Manually update parameters recursively
# Done in-place, as torch optimizers will do
with torch.no_grad():
    # Recursively
    for param in m.parameters():
        param -= eta * param.grad
    for name, param in m.named_parameters():
        print(f"{name}: {param.numpy()}")
linear1.weight: [[ 0.340191 0.4308803 0.16260189]
[-0.07866785 -0.35180938 0.44592345]]
linear2.weight: [[ 0.2432586 -0.7069635]]
linear2.bias: [-0.41789627]
width) of the representations, as well as tweaks to the optimizer and algorithms
Given that you may want to solve your problem with a variety of different hyperparameters, possibly running in parallel, you need a convenient way to pass the values and see the results
One model-, framework-, and OS-independent way to do this is to use command-line arguments
For example, if you have a Python file called mlp_regression_jax_nnx_logging.py that accepts arguments for the width and learning rate, you may call it with flags such as --width and --lr
Many Python frameworks exist to help you convert CLI arguments into Python function calls. One convenient tool is jsonargparse
Advantage: simply annotate a function with defaults, and it will generate the CLI
Then you can call python mlp_regression_jax_nnx_logging.py, python mlp_regression_jax_nnx_logging.py --width=64, etc.
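To illustrate the idea of generating a CLI from a function signature, here is a sketch using only the standard library (`argparse` and `inspect`) as a stand-in for what jsonargparse automates. It is not jsonargparse's implementation; `cli_from_signature` and the shortened `fit_model` are hypothetical helpers, and the sketch only handles `int`, `float`, and `str` defaults.

```python
import argparse
import inspect

def fit_model(width: int = 128, lr: float = 0.001, num_epochs: int = 300):
    # Placeholder body: a real version would train and return results
    return {"width": width, "lr": lr, "num_epochs": num_epochs}

def cli_from_signature(fn, argv=None):
    """Build --name flags from fn's keyword defaults, parse argv, and call fn."""
    parser = argparse.ArgumentParser(description=fn.__name__)
    for name, p in inspect.signature(fn).parameters.items():
        # Use the type of the default value to convert the command-line string
        parser.add_argument(f"--{name}", type=type(p.default), default=p.default)
    return fn(**vars(parser.parse_args(argv)))

# Equivalent to running: python script.py --width=64 --lr=0.0015
result = cli_from_signature(fit_model, ["--width=64", "--lr=0.0015"])
print(result)
```

Arguments that are not passed keep the defaults from the signature, which is what makes the function definition a single source of truth for both the Python and command-line interfaces.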
While the CLI file could save output for later interpretation, a common approach in ML is to log results to visualize the optimization process, compare results, etc.
One package for this is Weights and Biases
This logs calculations to a website, usually organized by project name, and lets you sort by hyperparameters, etc.
To use this, set up an account, add code to initialize it in your Python file, and then log intermediate results
Putting together the logging and the CLI, you can set up a process to run the code with a variety of parameters in a “sweep” (e.g., with a sweep file)
wandb agent <sweep_id>
python mlp_regression_jax_nnx_logging.py --width=128 --lr=0.0015 etc.
<sweep_id> on multiple computers/processes/etc.
lr and width to minimize test_loss (if logged with wandb.log({"test_loss": test_loss}), etc.)
program: lectures/examples/mlp_regression_jax_nnx_logging.py
name: Sweep Example
method: bayes
metric:
  name: test_loss
  goal: minimize
parameters:
  num_epochs:
    value: 300 # fixed for all calls
  lr: # uniformly distributed
    min: 0.0001
    max: 0.01
  width: # discrete values to optimize over
    values: [64, 128, 256]
nnx.Linear and nnx.relu layers
class MyMLP(nnx.Module):
    def __init__(self, din: int, dout: int, width: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, width, rngs=rngs)
        self.linear2 = nnx.Linear(width, width, rngs=rngs)
        self.linear3 = nnx.Linear(width, dout, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        x = nnx.relu(x)
        x = self.linear3(x)
        return x
This randomly generates a \(\theta\) and then generates data with
ERM: with \(m \in \mathcal{H}\), minimize the residuals for batch (X, Y)
Optimizer uses loss differentiated wrt \(m\) as discussed
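The steps above can be sketched end to end in plain JAX. Several pieces are assumptions rather than the lecture's code: the data-generating process is taken to be linear, \(y = x \cdot \theta + \sigma \epsilon\), since the formula is not shown here; the network is a one-hidden-layer MLP stored as a dictionary of arrays rather than an `nnx.Module`; and the optimizer is hand-written minibatch gradient descent rather than a library optimizer.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, din, width, dout):
    k1, k2 = jax.random.split(key)
    return {
        "linear1": {"kernel": jax.random.normal(k1, (din, width)) / jnp.sqrt(din),
                    "bias": jnp.zeros(width)},
        "linear2": {"kernel": jax.random.normal(k2, (width, dout)) / jnp.sqrt(width),
                    "bias": jnp.zeros(dout)},
    }

def mlp(params, X):
    hidden = jax.nn.relu(X @ params["linear1"]["kernel"] + params["linear1"]["bias"])
    return hidden @ params["linear2"]["kernel"] + params["linear2"]["bias"]

def residual_loss(params, X, Y):
    # ERM objective: mean squared residual on the batch (X, Y)
    return jnp.mean((mlp(params, X) - Y) ** 2)

# Assumed DGP: random theta, then y = x . theta + sigma * noise
N, M, sigma = 500, 2, 0.0001
k_theta, k_x, k_eps, k_init = jax.random.split(jax.random.PRNGKey(42), 4)
theta = jax.random.normal(k_theta, (M,))
X = jax.random.normal(k_x, (N, M))
Y = (X @ theta + sigma * jax.random.normal(k_eps, (N,)))[:, None]

lr, batch_size, num_epochs = 0.01, 100, 50

@jax.jit
def step(params, Xb, Yb):
    grads = jax.grad(residual_loss)(params, Xb, Yb)  # loss differentiated wrt the model
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)

params = init_mlp(k_init, M, 16, 1)
loss_start = residual_loss(params, X, Y)
for epoch in range(num_epochs):
    for i in range(0, N, batch_size):
        params = step(params, X[i:i + batch_size], Y[i:i + batch_size])
loss_end = residual_loss(params, X, Y)
print(loss_start, loss_end)
```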
In order to use jsonargparse, this creates a function signature with defaults
def fit_model(
    N: int = 500,
    M: int = 2,
    sigma: float = 0.0001,
    width: int = 128,
    lr: float = 0.001,
    num_epochs: int = 2000,
    batch_size: int = 512,
    seed: int = 42,
    wandb_project: str = "econ622_examples",
    wandb_mode: str = "offline",  # "online", "disabled"
):
    # ... generate data, fit model, save test_loss
To run this sweep, you can run the following command (checking the relative location of the file)
The output should be something along the lines of
(econ622) C:\Users\jesse\Documents\GitHub\ECON622_instructor>wandb sweep lectures/examples/mlp_regression_jax_nnx_sweep.yaml
wandb: Creating sweep from: lectures/examples/mlp_regression_jax_nnx_sweep.yaml
wandb: Creating sweep with ID: virfdcn6
wandb: View sweep at: https://wandb.ai/highdimensionaleconlab/ECON622_instructor-lectures_examples/sweeps/virfdcn6
wandb: Run sweep agent with: wandb agent highdimensionaleconlab/ECON622_instructor-lectures_examples/virfdcn6
wandb agent highdimensionaleconlab/ECON622_instructor-lectures_examples/virfdcn6
and go to web to see results in progress