Training Examples are Vector Fields and their Lie Brackets can be Computed

An ideal machine learning model would not care about the order in which training examples appear during training. From a Bayesian perspective, the training dataset is unordered, and every update based on seeing one additional example should commute with every other. For neural nets trained by gradient descent, however, this is not the case. This webpage will explain how to compute the effects of swapping the order of two training examples on a per-parameter level, and show the results of computing these quantities for a simple convnet model.

To get started, we just need to recognize one simple mathematical fact:

Training Examples are Vector Fields

If we are training a neural network with parameters $\theta \in \Theta = \mathbb{R}^\text{num params}$, then we can treat each training example as a vector field. In particular, if $x$ is a training example and $\mathcal{L}^{(x)}$ is the per-example loss for the training example $x$, then this vector field is:

$$ v^{(x)}(\theta) = -\nabla_{\theta} \mathcal{L}^{(x)}(\theta) $$

In other words, for a specific training example, the arrows of the resulting vector field point in the direction that the parameters should be updated.

In this view, a gradient update basically looks like moving in the direction of the vector field by the learning rate $\epsilon$.

$$ \theta' = \theta + \epsilon v^{(x)}(\theta). $$
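To make this concrete, here is a minimal sketch in PyTorch. The 2-parameter linear model, loss, and data are our own illustration (not the experiments below); the point is that one update step just moves $\theta$ along the example's field:

```python
import torch
from torch.func import grad

# Illustrative setup (an assumption for this sketch): a 2-parameter linear
# model with a squared-error loss on a single training example (x, y).
theta = torch.tensor([0.5, -0.3])

def per_example_loss(theta, x, y):
    # L^(x)(θ): loss of the model on one training example.
    return (theta @ x - y) ** 2

def v(theta, x, y):
    # The training example's vector field: v^(x)(θ) = -∇_θ L^(x)(θ).
    return -grad(per_example_loss)(theta, x, y)

x, y = torch.tensor([1.0, 2.0]), torch.tensor(0.7)
eps = 1e-2  # learning rate
theta_new = theta + eps * v(theta, x, y)  # one update: move along the field
```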

The Training Example Lie Bracket

One thing we can do with vector fields is to compute their Lie bracket. So if $x, y$ are training examples, we may compute:

$$ [v^{(x)}, v^{(y)}] = (v^{(x)}\cdot \nabla_\theta) v^{(y)} - (v^{(y)}\cdot \nabla_\theta) v^{(x)} $$

We can compute the Lie bracket of any two vector fields on $\Theta$, and so we can certainly compute the Lie bracket of the vector fields arising from two training examples. The Lie bracket of two training examples tells us about the order dependence of training on those examples. The Lie bracket of two vector fields is itself a vector field, and so, just like a gradient, for each parameter tensor we get a Lie bracket tensor of the same shape as that parameter tensor.

The Lie Bracket Tells us About Order-Dependence

We can interpret this quantity as the difference between updating on $x$ before $y$ vs after. Let's Taylor expand to see this. If $\epsilon$ is the learning rate, we'll want to expand to $O(\epsilon^2)$:

$$\theta' = \theta + \epsilon v^{(x)}(\theta)$$ $$ \theta'' = \theta' + \epsilon v^{(y)}(\theta') $$ $$= \theta + \epsilon v^{(x)}(\theta) + \epsilon v^{(y)}(\theta) + \epsilon^2 (v^{(x)}(\theta) \cdot \nabla_\theta) v^{(y)}(\theta) + O(\epsilon^3)$$

Now if we update $x,y$ in the other order, we get an $O(\epsilon^2)$ difference in the resulting parameters $\theta''$. Namely:

$$ \Delta \theta'' = \epsilon^2 \left( (v^{(x)}(\theta) \cdot \nabla_\theta) v^{(y)}(\theta) - (v^{(y)}(\theta) \cdot \nabla_\theta) v^{(x)}(\theta) \right) $$ $$ \Delta \theta'' = \epsilon^2 [v^{(x)}, v^{(y)}] (\theta) $$

So here we can see the significance of the Lie bracket: It tells us the difference in where our parameters end up based on which order we show the training examples in.
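This is easy to check numerically. The sketch below uses synthetic linear fields $v(\theta) = A\theta$ and $w(\theta) = B\theta$ (our own toy example, not actual training-example fields); for linear fields the bracket is exactly $(BA - AB)\theta$, and the order-swap difference is exactly $\epsilon^2 [v, w](\theta)$ with no higher-order remainder:

```python
import torch

# Toy linear vector fields v(θ) = Aθ and w(θ) = Bθ on R^2 (an illustration,
# not training-example fields). For linear fields,
# [v, w](θ) = (v·∇)w - (w·∇)v = (BA - AB)θ.
A = torch.tensor([[0.0, 1.0], [0.0, 0.0]], dtype=torch.float64)
B = torch.tensor([[0.0, 0.0], [1.0, 0.0]], dtype=torch.float64)
v = lambda th: A @ th
w = lambda th: B @ th

theta = torch.tensor([1.0, 2.0], dtype=torch.float64)
eps = 1e-2

def two_steps(first, second, th):
    # Update along `first`, then along `second`, each with step size eps.
    th = th + eps * first(th)
    return th + eps * second(th)

# The difference between the two update orders equals ε² [v, w](θ) here.
delta = two_steps(v, w, theta) - two_steps(w, v, theta)
bracket = (B @ A - A @ B) @ theta
```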

Note that by the linearity of the Lie bracket, swapping the order of two minibatches has an effect given by averaging over all pairs of examples.
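To spell this out: if $v^{(B)}$ denotes the averaged update field of a minibatch $B$, then bilinearity of the bracket gives

$$ v^{(B)}(\theta) = \frac{1}{|B|} \sum_{x \in B} v^{(x)}(\theta), \qquad [v^{(B_1)}, v^{(B_2)}] = \frac{1}{|B_1|\,|B_2|} \sum_{x \in B_1} \sum_{y \in B_2} [v^{(x)}, v^{(y)}]. $$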

Prior Work

When searching the literature for work on the Lie brackets of training examples, the earliest description we found was by Dherin in 2023, who connects the bracket's ability to measure commutativity of updates to implicit biases in neural net training.

We go further here by explicitly computing the bracket value at various checkpoints in the training of an actual convnet.

Experiment Details

We replicate the MXResNet architecture (without attention layers) and train it on the CelebA dataset for 5000 steps at a batch size of 32, saving weight checkpoints from time to time. The optimizer is Adam, with the following parameters:


lr = 5e-3
betas = (0.8, 0.999)

The CelebA dataset has 40 binary attributes (such as Male or Black_Hair) and the neural net is tasked with predicting each of these independently and simultaneously (averaged binary classification loss).
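In code, we take this to mean a mean binary cross-entropy over all 40 attribute logits. This is a sketch of our reading of the objective, not the repo's exact loss function:

```python
import torch
import torch.nn.functional as F

# Sketch of the objective as described: 40 independent binary predictions
# per image, with the binary classification losses averaged together.
# Shapes match the experiment's batch size (32) and attribute count (40).
logits = torch.randn(32, 40)                     # model outputs, one logit per attribute
targets = torch.randint(0, 2, (32, 40)).float()  # 40 binary labels per image

loss = F.binary_cross_entropy_with_logits(logits, targets)  # mean over all 32 × 40 predictions
```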

We evaluated each checkpoint of the model on a batch of 32 examples from the test set. We computed Lie brackets between only the first 6 of these test examples to limit disk space usage, as each individual Lie bracket has the same size as a full checkpoint of the model. For each of these brackets representing a swap of two examples, we show how all 40 logits for all 32 test examples in the batch are perturbed when the two examples are swapped.

Results

We have some things to say about the results, but first, try exploring them yourself! The slider controls which checkpoint from the training process we're examining, and you can click the buttons to see data about particular Lie brackets. Since $[u_i, u_j] = -[u_j, u_i]$, brackets across the diagonal from each other are just negatives of each other.

Loss History

Slider: Select a checkpoint.

Lie Brackets of Example Pairs

Click any off-diagonal pair to inspect bracket $[u_i, u_j]$ details for examples $i$ vs $j$. Note that it varies by which checkpoint you've selected. Heatmaps show how much logits for a specific input and feature change when the two selected examples are swapped.

Patterns

Lie Bracket Magnitudes are Proportional to Gradient Magnitudes

If we look at the tensors that the Lie bracket provides for each parameter, their RMS magnitudes span many orders of magnitude (just like the gradients for these tensors do). But if we plot RMS bracket magnitudes against RMS gradient magnitudes for each parameter tensor on a log-log scale, we find a remarkably tight correlation between the two. Indeed, it appears that for each bracket, a simple one-parameter fit (a line of slope 1 on log-log axes) pins down a constant of proportionality between bracket magnitudes and gradient magnitudes.
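As a sketch of what we mean by a one-parameter fit (with synthetic magnitudes, not our measured data): on log-log axes a slope-1 line has a single free parameter, its intercept, which is the log of the proportionality constant and has a closed-form least-squares estimate.

```python
import torch

# Synthetic per-tensor magnitudes for illustration: bracket RMS ≈ c · gradient RMS,
# with gradient magnitudes spanning several orders of magnitude plus log-space noise.
torch.manual_seed(0)
grad_rms = torch.exp(3 * torch.randn(50))
true_c = 0.02
bracket_rms = true_c * grad_rms * torch.exp(0.1 * torch.randn(50))

# Slope-1 line on log-log axes: log(bracket) = log(grad) + log(c).
# Its single parameter log(c) is just the mean log-ratio:
log_c = (bracket_rms.log() - grad_rms.log()).mean()
c_hat = log_c.exp().item()  # recovered constant of proportionality, ≈ true_c
```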

This is interesting, because it suggests that the magnitude of a given Lie bracket is mostly determined by:

1. The $v^{(x)}$ part of $(v^{(x)}\cdot \nabla_\theta) v^{(y)}$ (and similarly for the other term).

2. Factors that are independent of which parameter tensor we're looking at such as how far along in the training we are and how "intrinsically non-commutative" the training examples we're taking the bracket of are.

In other words, it suggests that for a given bracket, the $\nabla_\theta v^{(y)}$ part of $(v^{(x)}\cdot \nabla_\theta) v^{(y)}$ has a relatively constant magnitude across all the parameters of the net (and similarly for the other term).

Non-Commutativity's Effect on Predictions may be able to Flag Modelling Issues

You may have noticed that past checkpoint 600 or so, the Black_Hair and Brown_Hair logits tended to have large deltas under most of the Lie brackets in the matrix. This means that predictions for these two features varied a lot based on example ordering. We have a hypothesis for why this may be the case:

Consider that in the dataset, black hair and brown hair are never simultaneously labeled True. The other 3 combinations of these features do occur, but a value of True for both at once should not. However, the model outputs separate predictions for each feature, and the resulting joint distribution can only be a product of the individual predictions. In other words, the loss function implicitly assumes that the model's predictions are independent of each other. If the model is unsure whether a person's photo shows black or brown hair (which can be quite common, depending on lighting), it would predict a 50% chance for each feature. The loss function interprets this as a 25% chance of each of the 4 combinations, but what the model would probably like to predict is a 50:50 split between (False, True) and (True, False).
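A quick illustration of the point (plain Python, our own toy numbers): under the product-of-marginals assumption, a genuinely 50/50 model is forced to put 25% on the impossible (True, True) combination.

```python
# A model that is 50/50 on each hair color can only express the joint
# distribution as a product of its two independent marginal predictions.
p_black, p_brown = 0.5, 0.5

joint = {
    (black, brown): (p_black if black else 1 - p_black)
                    * (p_brown if brown else 1 - p_brown)
    for black in (False, True) for brown in (False, True)
}
# joint[(True, True)] is 0.25 even though that combination never occurs in
# the dataset; the desired 50:50 split over (False, True) and (True, False)
# is not expressible as a product of marginals.
```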

Our hypothesis is that this inadequacy in the loss function is what is driving the logit deltas for these features to tend to be larger compared to the others.

Appendix: Lie Bracket Computation Source Code (PyTorch)

Project GitHub Repo

Specific functions to compute a Lie bracket:


import torch

def dotgrad(v, u_x, θ):
  """ Compute (v⋅∇)u_x(θ) using forward-mode autodiff (a jvp). """
  _, ans = torch.func.jvp(u_x, (θ,), (v,))
  ans = params_dict_detach(ans)  # repo helper: detaches every tensor in a params dict
  return ans

def lie_brackets(u, x, θ, chunk_size=None):
  """ Compute a Lie bracket for the update field u_x(θ). x is expected to have
      a batch dim, num_samples. Brackets are computed between all elements of the batch.
      Output is a dict like θ: {"name": (*shape)...} except shaped like this:
      {"name": (num_samples, num_samples, *shape)...} """
  v = torch.func.vmap(
    (lambda x, θ: params_dict_detach(u(x)(θ))),
    in_dims=(0, None)
  )(x, θ)
  v_dotgrad = torch.func.vmap( # map over vector multiplicity
    dotgrad,
    in_dims=(0, None, None),
  )
  Jv = torch.func.vmap(
    (lambda v, x, θ: v_dotgrad(v, u(x), θ)), # map over x multiplicity
    in_dims=(None, 0, None),
    chunk_size=chunk_size
  )(v, x, θ)
  return {
    # based on order of vmaps, dim 1 tells us the direction of the derivative, dim 0 is x multiplicity
    key: Jv[key].transpose(0, 1) - Jv[key]
    for key in Jv
  }