Idle Speculations on Pipeline Parallelism
2024 Jul 27
Epistemic status: Pure speculation; I have 1 GPU, not dozens, so I have no reason to try to implement this stuff myself.
When training large models, we might have many GPUs that we'd like to use to speed up the process. This requires some way to parallelize training. Two options here are Data Parallelism and Pipeline Parallelism.
If our model is very large, we might have trouble fitting it into the memory of a single GPU. Pipeline parallelism is convenient in such a case, because it treats the model as a pipeline, where each GPU is responsible for implementing one part of the pipeline. E.g., if the model consists of 64 identical blocks, then we could assign a GPU to each block. The activations coming out of each block are then passed to the next block. For inference, these activations are the only data that must be communicated between GPUs, which makes pipeline parallelism very bandwidth-efficient. In order to keep each GPU busy, we must feed in a continuous stream of data, adding more work before the work we previously added makes it through the pipeline.
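Here is a toy sketch of that streaming behaviour. It shows which microbatch each stage is working on at each step. It also shows how the idle steps at the start and end shrink, in relative terms, as the stream gets longer. The sketch makes two simplifications: every stage takes exactly one time step per microbatch, and passing activations along is free. `pipeline_forward_schedule` and `utilization` are hypothetical helper names.

```python
def pipeline_forward_schedule(num_stages, num_microbatches):
    """grid[t][s] = index of the microbatch on stage s at step t, or None if idle.

    Toy model (an assumption): every stage takes one time step per microbatch,
    and communication between stages is free.
    """
    total_steps = num_stages + num_microbatches - 1
    grid = []
    for t in range(total_steps):
        row = []
        for s in range(num_stages):
            mb = t - s  # microbatch mb reaches stage s at step mb + s
            row.append(mb if 0 <= mb < num_microbatches else None)
        grid.append(row)
    return grid


def utilization(grid):
    """Fraction of (step, stage) slots that are doing useful work."""
    busy = sum(cell is not None for row in grid for cell in row)
    return busy / (len(grid) * len(grid[0]))


if __name__ == "__main__":
    grid = pipeline_forward_schedule(num_stages=4, num_microbatches=8)
    for t, row in enumerate(grid):
        print(t, " ".join("." if mb is None else str(mb) for mb in row))
    print(f"utilization = {utilization(grid):.3f}")  # 32 busy slots / 44 total = 0.727
```

With S stages and M microbatches, the utilization in this toy model is M / (M + S - 1). That only approaches 1 when the stream is much longer than the pipeline.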
Data parallelism is where we have a very large batch which we split up across multiple GPUs. Each GPU computes the output of the model on its own sub-batch (and, when training, the gradient for that sub-batch; these gradients are then averaged across GPUs). We'll mostly focus on pipeline parallelism today, though both methods can be used at once.
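A minimal NumPy sketch of why splitting the batch works for training. It assumes a plain linear least-squares model and equal-sized sub-batches; the model, `mse_grad`, and the sizes are all made up for illustration. The "GPUs" here are just slices of the array, and the averaging step stands in for the all-reduce a real setup would use.

```python
import numpy as np


def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y)**2) with respect to w."""
    return X.T @ (X @ w - y) / len(y)


rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
y = rng.normal(size=64)
w = np.zeros(3)

num_workers = 4
# Each simulated "GPU" gets an equal-sized slice of the batch.
shards = zip(np.array_split(X, num_workers), np.array_split(y, num_workers))
local_grads = [mse_grad(w, Xs, ys) for Xs, ys in shards]

# With equal shard sizes, averaging the local gradients (an all-reduce in
# practice) reproduces the full-batch gradient.
avg_grad = np.mean(local_grads, axis=0)
full_grad = mse_grad(w, X, y)
print(np.allclose(avg_grad, full_grad))  # True
```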
Though it's great for inference, pipeline parallelism has some issues for training. Check this article for details, but I'll give a quick overview. The problem is that we have to compute backwards passes, and the weights are also continually updated by the optimization algorithm. The need to compute backwards passes introduces a bubble into the system, and the size of the bubble increases with the length of the pipeline. See this figure from the DAPPLE paper:
What we can see is that the Dapple method has much friendlier memory usage than the method originally put forward in GPipe. In hindsight, it's clear that we should try to do the backward pass as soon as we can after the forward pass completes, in order to free up all the saved activations. This does get complicated, because the amount of work in a backward() is not necessarily the same as the amount of work in a forward(). Let's neglect that issue for now, though, and trust that the Dapple paper has provided good methods for dealing with it!
removing the bubble
Dapple doesn't get rid of the bubble; it just spreads it out. The origin of the bubble is the desire to cleanly compute a bunch of gradients, then do an optimization step, then repeat the same thing all over again. Each repetition introduces a new bubble. To make the bubble small relative to the middle section where all GPUs are being used, we have to accumulate gradients over a large number of pipeline passes. This means that to get good efficiency, the number of passes we accumulate over (and hence the batch size) must be many times the pipeline length. To get rid of the bubble entirely, check out the PipeMare paper. Along with each weight, this paper suggests that we store a velocity estimate for that weight, so that in the backward pass we can get a gradient that is approximately correct.
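Here is a back-of-the-envelope sketch that puts a rough number on the accumulation point. It assumes a flush-every-step schedule (GPipe or Dapple style) with S stages and M microbatches accumulated per optimizer step. It also assumes, unrealistically, that a forward and a backward each take one time unit per stage. Under those assumptions each stage is busy for 2M units out of 2(M + S - 1), so the bubble fraction is (S - 1) / (M + S - 1). The function names are hypothetical.

```python
def bubble_fraction(num_stages, num_microbatches):
    """Idle fraction of a flush-per-step pipeline schedule.

    Assumes forward and backward each take one time unit per stage, so each
    stage is busy for 2*M units out of a step lasting 2*(M + S - 1) units.
    """
    S, M = num_stages, num_microbatches
    return (S - 1) / (M + S - 1)


def microbatches_for_bubble(num_stages, target):
    """Smallest number of accumulated microbatches M with bubble <= target."""
    assert target > 0, "a zero target is unreachable when num_stages > 1"
    M = 1
    while bubble_fraction(num_stages, M) > target:
        M += 1
    return M


if __name__ == "__main__":
    print(bubble_fraction(8, 8))            # 7/15: almost half the time is idle
    print(microbatches_for_bubble(8, 0.1))  # 63, roughly 8x the pipeline length
```

So with 8 stages, getting the bubble down to 10% means accumulating about 63 microbatches per step. That is the "many times the pipeline length" effect in concrete terms.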
If you say that approximate correctness is not good enough and you'd like exact correctness, here's a scheme with a similar concept to PipeMare that solves the same problem. (This assumes a possibly unrealistic amount of synchronization, so we might prefer a scheme like PipeMare in practice.) First, consider the following Dapple-like scheduling scheme (middle): we can see how it is made by interleaving forward passes (left) and backward passes (right).
This is with backwards passes, but without weight updates. Let's say we insert weight updates like so:
On the right, I've isolated a few forward/backward passes so we can more easily see the causal structure. The spacing of the weight updates is carefully chosen so that each forward/backward pass encloses one weight update. This means that we need to store the weights of the previous time step, so that we can keep using them for forward/backward passes that started before the last update. But that's it: we can discard weights from two steps before. The upshot is that by the time we've computed the gradient for the weights at time t, the weights are already at version t + 1. So we must use a gradient descent update something like the following:
$$w_{t+1} = w_t - \eta \left.\frac{\partial \mathcal{L}}{\partial w}\right|_{w = w_{t-1}}$$
where the gradient is evaluated at $w_{t-1}$, i.e. it is stale by one step.
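A minimal sketch of that stale update, w_{t+1} = w_t - η ∇𝓛(w_{t-1}), under two assumptions: we start with w_{-1} = w_0, and we can evaluate the gradient on any weight version. Only two weight versions are ever kept, mirroring the point that older weights can be discarded. `stale_sgd` and `quadratic_grad` are hypothetical names. The test case is the toy loss 𝓛 = ½ λ w² with λ = 1.

```python
import numpy as np


def stale_sgd(grad_fn, w0, lr, num_steps):
    """Iterate w_{t+1} = w_t - lr * grad_fn(w_{t-1}), starting from w_{-1} = w_0.

    Only two weight versions are alive at any time: the current weights and
    the previous version that in-flight gradients were computed against.
    """
    w_prev = np.array(w0, dtype=float)  # w_{t-1}
    w_curr = w_prev.copy()              # w_t
    for _ in range(num_steps):
        w_next = w_curr - lr * grad_fn(w_prev)  # gradient is one step stale
        w_prev, w_curr = w_curr, w_next         # version from two steps back is dropped
    return w_curr


def quadratic_grad(w, lam=1.0):
    """Gradient of the toy loss 0.5 * lam * w**2."""
    return lam * w


if __name__ == "__main__":
    print(stale_sgd(quadratic_grad, 1.0, lr=0.5, num_steps=200))  # decays toward 0
    print(stale_sgd(quadratic_grad, 1.0, lr=1.5, num_steps=200))  # blows up
```

Plain gradient descent at lr = 1.5 would still converge on this loss, since |1 - 1.5| < 1. So the stale version really does need a smaller step, consistent with the stability comparison that follows.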
This is not too bad of a problem. One can study gradient-based optimization methods by looking at their behaviour on a simple loss function like $\mathcal{L} = \frac{1}{2} \lambda w^2$. By taking λ to be any complex number, we can then obtain a region of stability in the complex plane. (We have to find the spectral radius of the matrix representing how the weights get updated in a full cycle of the optimization algorithm. If it's less than 1, we're in the region of stability. Full stability analysis code is at the end of the article.)
Region of stability for regular gradient descent:
Region of stability for gradient descent with stale gradients:
Comparing the two, we see that the region of stability gets smaller when gradients are stale, so we might have to reduce the learning rate by a factor of 2 or so. But the region of stability is still basically reasonable: it touches the origin, etc.
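For real λ > 0 (the ordinary convex case), the factor of 2 can be checked by hand. Write c = ηλ and compare the two update rules on the quadratic loss:

```latex
\begin{aligned}
\text{plain GD:}\quad & w_{t+1} = (1 - c)\, w_t,
  && \text{stable} \iff |1 - c| < 1 \iff 0 < c < 2, \\
\text{stale GD:}\quad & w_{t+1} = w_t - c\, w_{t-1},
  && \text{characteristic polynomial } z^2 - z + c = 0.
\end{aligned}
```

The two roots of the stale polynomial sum to 1 and multiply to c. For 0 < c ≤ 1/4 they are real and lie in (0, 1). For c > 1/4 they are complex conjugates with |z|² = c, so both are inside the unit circle exactly when c < 1. On the real axis, the stale method is therefore stable for 0 < ηλ < 1, half the range of plain gradient descent.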
balancing memory usage by folding the pipeline
There's one more issue to consider: when we compute gradients we need to know the activations from the forward pass. This is the context that gets "saved for backward". If we look at the triangular shape of the forward and backward pass, we can see a bit of a problem: The early layers in the network have to store their activations for much longer than the later layers. Consider the following simple example, a forward and backward pass through 4 identical layers split across two GPUs (dashed line indicates division of work between the two devices).
The first device has to store its activations for much longer than the second device, even though it handles the same number of layers. If we consider a longer pipeline with many forward and backward passes being processed at once, the first device in the pipeline is clearly going to run out of memory long before the last one. Though it's hard to avoid the large memory requirements this imposes, we can at least make sure they are balanced between GPUs. And the trick is a simple one. Just fold the pipeline so that each GPU handles an even number of layers, with the layers paired together: first with last, second with second-to-last, etc. This balances the memory usage nicely, with memory-hungry early layers being paired with memory-light late layers. It means that a forward pass travels from the first GPU to the last one and back, and a backward pass does the same, for a total of two round trips.
Since new data is continually being processed through the pipeline, there should not be too much imbalance in when the memory is needed.
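Here is a small sketch of the folding trick. It uses a crude proxy for memory: the number of time steps each layer's saved activations stay alive during one forward and backward pass. It assumes one time unit per layer in each direction. Layer i then runs forward at step i and backward at step 2L - 1 - i, holding its activations for 2(L - i) - 1 steps. With a continuous stream of microbatches, the number held at once should scale roughly with this. All function names are made up for illustration.

```python
def contiguous_assignment(num_layers, num_gpus):
    """Usual unfolded pipeline: consecutive blocks of layers per GPU."""
    per_gpu = num_layers // num_gpus
    return [i // per_gpu for i in range(num_layers)]


def folded_assignment(num_layers, num_gpus):
    """Folded pipeline: layer i shares a GPU with layer num_layers - 1 - i."""
    assert num_layers % (2 * num_gpus) == 0, "assume an even number of layers per GPU"
    half_per_gpu = num_layers // (2 * num_gpus)
    return [min(i, num_layers - 1 - i) // half_per_gpu for i in range(num_layers)]


def activation_residency(num_layers, assignment):
    """Per GPU, the total number of steps its layers hold saved activations.

    Toy timing (an assumption): layer i runs forward at step i and backward
    at step 2*num_layers - 1 - i, so it holds activations for
    2*(num_layers - i) - 1 steps.
    """
    totals = [0] * (max(assignment) + 1)
    for i, gpu in enumerate(assignment):
        totals[gpu] += 2 * (num_layers - i) - 1
    return totals


if __name__ == "__main__":
    L, G = 4, 2  # the 4-layer, 2-GPU example from the text
    print(activation_residency(L, contiguous_assignment(L, G)))  # [12, 4]
    print(activation_residency(L, folded_assignment(L, G)))      # [8, 8]
```

Each folded pair (i, L - 1 - i) adds up to exactly 2L, which is why the balance is perfect in this toy model.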
secret sauce?
Interestingly, this stuff is actually fairly old. PipeMare is from 2020. Though I haven't heard of pipeline folding before, I'm sure it's long since been considered by the big AI labs, and is either in use, or discarded for some even better alternative. I'd guess a lot of these techniques are thought of as "secret sauce" by the big companies, and thus not published, since they provide a competitive advantage.
appendix: stability analysis code
import numpy as np
import matplotlib.pyplot as plt

β = 0.8                     # momentum coefficient for the momentum methods
IMW, IMH = 400, 400         # image width and height in pixels
EXTENT = [0., 4., -2., 2.]  # [Re min, Re max, Im min, Im max] of the plotted λ region


def make_matrix(shape_prefix, lst):
    """ broadcast variables to make a matrix """
    h, w = len(lst), len(lst[0])
    ans = np.zeros(shape_prefix + (h, w), dtype=complex)
    for i in range(h):
        for j in range(w):
            ans[..., i, j] = lst[i][j]
    return ans


def A_sgd(λ):
    # w_{t+1} = (1 - lr*λ) w_t
    lr = 1.
    return make_matrix((IMH, IMW), [[1. - lr*λ]])


def A_polyak(λ):
    # heavy-ball momentum, state (w_t, w_{t-1})
    lr = 1. - β  # corrected LR for momentum method
    return make_matrix((IMH, IMW), [
        [1. + β - lr*λ, -β],
        [1., 0]])


def A_nesterov(λ):
    # state (v_t, w_t), gradient taken at the look-ahead point w_t + β v_t
    lr = 1. - β  # corrected LR for momentum method
    Q = 1. - lr*λ
    return make_matrix((IMH, IMW), [
        [β*Q, Q - 1.],
        [β*Q, Q]])


def A_stale_sgd(λ):
    # w_{t+1} = w_t - lr*λ*w_{t-1}, state (w_t, w_{t-1})
    lr = 1.
    return make_matrix((IMH, IMW), [
        [1., -lr*λ],
        [1., 0.]])


def get_max_eigenvalue_magnitude(mat):
    # spectral radius of each matrix in the (possibly batched) stack
    vals = np.linalg.eigvals(mat)
    return np.absolute(vals).max(-1)


if __name__ == "__main__":
    X, Y = np.meshgrid(np.linspace(EXTENT[0], EXTENT[1], IMW), np.linspace(EXTENT[2], EXTENT[3], IMH))
    λ = X + 1j*Y
    for A_fn in [A_sgd, A_stale_sgd]:
        A_1 = A_fn(0.01)[0, 0]  # specifically for λ = 0.01, what is the convergence rate?
        print(get_max_eigenvalue_magnitude(A_1))  # did we correctly tune lr for momentum methods?
        A = A_fn(λ)
        # stable points (spectral radius < 1) come out below 1; unstable ones are clipped to 1
        color = np.minimum(1., get_max_eigenvalue_magnitude(A)**20)
        plt.imshow(color, cmap="pink", extent=EXTENT, origin="lower")
        plt.show()