Backpropagation

  • Build makemore Part 4: Becoming a Backprop Ninja

  • Derive Backprop Gradients step by step

    dprobs = dlogprobs / probs
    dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
    dcounts = dprobs * counts_sum_inv
    dcounts_sum = -dcounts_sum_inv * counts_sum ** -2
    dcounts += dcounts_sum.broadcast_to(counts.shape)
    dnorm_logits = dcounts * norm_logits.exp() # norm_logits.exp() is actually counts
    dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
    dlogits = dnorm_logits.clone()
    tmp = torch.zeros_like(logits)
    tmp[range(n), logits.max(1, keepdim=True).indices.view(-1)] = 1 # try F.one_hot
    dlogits += dlogit_maxes * tmp
    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(0, keepdim=False)
    # dhpreact = dh * (1 - torch.tanh(hpreact) ** 2)
    # dhpreact = (1.0 - h ** 2) * dh # figure out later
    dhpreact = hpreact.grad.clone()
    dbngain = (dhpreact * bnraw).sum(0, keepdim=True)
    dbnbias = dhpreact.sum(0, keepdim=True)
    dbnraw = dhpreact * bngain
    dbnvar_inv = (dbnraw * bndiff).sum(0, keepdim=True)
    dbndiff = dbnraw * bnvar_inv
    # dbnvar = dbnvar_inv * (-0.5) * bnvar_inv ** 3
    dbnvar = dbnvar_inv * (-0.5) * (bnvar + 1e-5) ** -1.5
    dbndiff2 = 1.0 / (n-1) * torch.ones_like(bndiff2) * dbnvar
    dbndiff += 2 * bndiff * dbndiff2
    dbnmeani = -dbndiff.sum(0, keepdim=True)
    dhprebn = dbndiff.clone()
    dhprebn += dbnmeani * 1.0 / n * torch.ones_like(hprebn)
    dembcat = dhprebn @ W1.T
    dW1 = embcat.T @ dhprebn
    db1 = dhprebn.sum(0, keepdim=False)
    demb = dembcat.view(emb.shape)
    dC = torch.zeros_like(C)
    for k in range(Xb.shape[0]):
        for j in range(Xb.shape[1]):
            ix = Xb[k,j]
            dC[ix] += demb[k,j]
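
    As a side note, the final double loop that scatters demb into dC can be vectorized. A minimal sketch, assuming the same Xb, demb, and C tensors as above:

    dC = torch.zeros_like(C)
    # index_add_ scatter-adds each flattened row of demb into the row of dC
    # selected by the corresponding character index in Xb; duplicate indices
    # accumulate, exactly like the loop above.
    dC.index_add_(0, Xb.view(-1), demb.view(-1, demb.shape[-1]))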

    dlogprobs

    Notation-wise, dlogprobs stands for the gradient of the loss with respect to logprobs. To start, we have the following

    $$loss = -\frac{1}{n}\sum_{i=1}^{n}\sum_{j\in Y_{b}}logprobs_{i,j}$$

    which easily yields the gradient as follows.

    $$\left(\frac{dloss}{dlogprobs}\right)_{i,j}=\begin{cases} -\frac{1}{n}, & j\in Y_{b} \\ 0, & \text{otherwise} \end{cases}$$
    dlogprobs = torch.zeros_like(logprobs)
    dlogprobs[range(n), Yb] = -1. / n
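
    As a sanity check, this manual gradient can be compared with autograd on a tiny random example. A minimal sketch, with hypothetical shapes:

    import torch

    n, vocab = 4, 7
    logprobs = torch.randn(n, vocab, requires_grad=True)
    Yb = torch.randint(0, vocab, (n,))
    loss = -logprobs[range(n), Yb].mean()
    loss.backward()  # autograd reference

    dlogprobs = torch.zeros_like(logprobs)
    dlogprobs[range(n), Yb] = -1. / n
    print(torch.allclose(dlogprobs, logprobs.grad))  # True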

    dprobs

    Continuing the process, we have

    $$logprobs_{i,j}=\log(probs_{i,j})$$

    Hence

    $$\left(\frac{dloss}{dprobs}\right)_{i,j}=\left(\frac{dloss}{dlogprobs}\right)_{i,j}\cdot\left(\frac{dlogprobs}{dprobs}\right)_{i,j}=\left(\frac{dloss}{dlogprobs}\right)_{i,j}\cdot\frac{1}{probs_{i,j}}$$
    dprobs = dlogprobs / probs

    dcounts_sum_inv

    For simplicity of notation, denote counts_sum_inv by csi. Then probs = counts * counts_sum_inv is in fact

    $$probs_{i,j}=counts_{i,j}\cdot csi_{i}$$

    >>> counts.shape, counts_sum_inv.shape
    (torch.Size([32, 27]), torch.Size([32, 1]))

    Since $csi_{i}$ multiplies every entry of row $i$, each element contributes

    $$\left(\frac{dprobs}{dcsi}\right)_{i,j}=counts_{i,j}$$

    and these contributions are summed along the row:

    $$\left(\frac{dloss}{dcsi}\right)_{i}=\sum_{j}\left(\frac{dloss}{dprobs}\right)_{i,j}\cdot counts_{i,j}$$
    dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)

    dcounts

    Note that one also has

    probs = counts * counts_sum_inv

    which leads to

    $$\left(\frac{dprobs}{dcounts}\right)_{i,j}=csi_{i}$$

    However, other than contributing to the loss through probs, counts also does so through counts_sum and then counts_sum_inv; the complete gradient dcounts remains to be accumulated.
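
    The contribution through the probs branch is the following line from the full listing above:

    dcounts = dprobs * counts_sum_inv

    The second branch is added below, once dcounts_sum is known.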

    dcounts_sum

    Denoting counts_sum by cs, the forward pass line

    counts_sum_inv = counts_sum**-1

    leads to

    $$\left(\frac{dcsi}{dcs}\right)_{i}=-cs_{i}^{-2}$$

    and then

    $$\left(\frac{dloss}{dcounts\_sum}\right)_{i}=-\left(\frac{dloss}{dcsi}\right)_{i}\cdot counts\_sum_{i}^{-2}$$
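
    In code, this also completes dcounts with its second branch; both lines appear in the full listing above:

    dcounts_sum = -dcounts_sum_inv * counts_sum ** -2
    dcounts += dcounts_sum.broadcast_to(counts.shape)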

    dnorm_logits

    Since counts = norm_logits.exp() and the derivative of exp is exp itself, the local derivative is simply counts:

    dnorm_logits = dcounts * norm_logits.exp() # norm_logits.exp() is actually counts

    dlogit_maxes

    Since norm_logits = logits - logit_maxes and each logit_maxes entry is broadcast across its row with coefficient -1, the gradient is the negated row sum:

    dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)

    dlogits

    logits contributes to the loss twice: directly through norm_logits, and through logit_maxes. The first branch passes the gradient through unchanged; the second routes dlogit_maxes into each row's argmax position:

    dlogits = dnorm_logits.clone()
    tmp = torch.zeros_like(logits)
    tmp[range(n), logits.max(1, keepdim=True).indices.view(-1)] = 1 # try F.one_hot
    dlogits += dlogit_maxes * tmp
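
    Following the # try F.one_hot hint, the scatter into tmp can be replaced by F.one_hot. A minimal sketch, assuming torch.nn.functional is imported as F:

    # One-hot mask of each row's argmax; the integer one-hot times the float
    # gradient promotes to float, matching the tmp version above.
    dlogits += dlogit_maxes * F.one_hot(logits.max(1).indices, num_classes=logits.shape[1])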

    Backprop through cross_entropy but All in One Go

    Basically what happens in the forward pass can be described by the following pseudocode

    logprobs = log(softmax(logits, 1))
    loss = -mean(logprobs[range(n), Yb])

    To discuss the derivative of each single element, for every $i$ denote $(Y_{b})_{i}$ by $y$. For simplicity of notation, denote $logits$ by $lg$.

    $$\begin{aligned} loss&=-\frac{1}{n}\sum_{i=1}^{n}logprobs_{i,y}\\ &=-\frac{1}{n}\sum_{i=1}^{n}\log(probs_{i,y})\\ &=-\frac{1}{n}\sum_{i=1}^{n}\log\left(\frac{\exp\{lg_{i,y}\}}{\sum_{k}\exp\{lg_{i,k}\}}\right) \end{aligned}$$

    Now apply the chain rule to derive the derivatives. If $j\neq y$,

    $$\left(\frac{dloss}{dlg}\right)_{i,j}=-\frac{1}{n}\cdot\frac{\sum_{k}\exp\{lg_{i,k}\}}{\exp\{lg_{i,y}\}}\cdot(-1)\cdot\frac{\exp\{lg_{i,y}\}}{(\sum_{k}\exp\{lg_{i,k}\})^{2}}\cdot\exp\{lg_{i,j}\}$$

    which yields

    $$\left(\frac{dloss}{dlg}\right)_{i,j}=\frac{1}{n}\cdot\frac{\exp\{lg_{i,j}\}}{\sum_{k}\exp\{lg_{i,k}\}}=\frac{1}{n}\cdot\text{softmax}(lg_{i,\cdot})_{j}$$

    If $j=y$,

    $$\left(\frac{dloss}{dlg}\right)_{i,j}=-\frac{1}{n}\cdot\frac{\sum_{k}\exp\{lg_{i,k}\}}{\exp\{lg_{i,y}\}}\cdot\frac{\exp\{lg_{i,y}\}\cdot\sum_{k}\exp\{lg_{i,k}\}-\exp\{lg_{i,y}\}^{2}}{(\sum_{k}\exp\{lg_{i,k}\})^{2}}$$

    which yields

    $$\left(\frac{dloss}{dlg}\right)_{i,j}=-\frac{1}{n}+\frac{1}{n}\cdot\frac{\exp\{lg_{i,y}\}}{\sum_{k}\exp\{lg_{i,k}\}}=\frac{1}{n}\cdot\left(\text{softmax}(lg_{i,\cdot})_{y}-1\right)$$

    While the two cases are discussed separately, they share the common softmax term, which allows a compact implementation.

    dlogits = F.softmax(logits, 1)
    dlogits[range(n), Yb] -= 1
    dlogits /= n
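
    The fused gradient can likewise be verified against autograd on a tiny random example. A minimal sketch, with hypothetical shapes:

    import torch
    import torch.nn.functional as F

    n, vocab = 8, 27
    logits = torch.randn(n, vocab, requires_grad=True)
    Yb = torch.randint(0, vocab, (n,))
    F.cross_entropy(logits, Yb).backward()  # autograd reference

    with torch.no_grad():
        dlogits = F.softmax(logits, 1)
        dlogits[range(n), Yb] -= 1
        dlogits /= n
    print(torch.allclose(dlogits, logits.grad))  # True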

    Finish calculating the rest of the derivatives in the same fashion.

    A typical training loop looks like this. One thing worth mentioning is the line optimizer.zero_grad(). Gradients in PyTorch are accumulated by default. Without zeroing them, gradients from multiple backpropagation steps would add up, leading to incorrect model updates. In other words, zeroing gradients ensures that each backward pass calculates gradients based only on the current mini-batch.

    # Training loop
    losses = []
    for epoch in range(num_epochs):
        epoch_loss = 0
        for batch_X, batch_y in dataloader:
            # Forward pass
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # Record average loss for the epoch
        avg_loss = epoch_loss / len(dataloader)
        losses.append(avg_loss)

    These steps are part of each training loop iteration and help in adjusting the model’s parameters to improve its predictions.

    1. optimizer.zero_grad():

      • Purpose: Clears old gradients.
      • What Happens: Resets the gradients of all model parameters to zero. This prevents the accumulation of gradients from multiple backward passes, ensuring the gradients for each mini-batch are computed independently.
    2. loss.backward():

      • Purpose: Computes gradients.
      • What Happens: Performs backpropagation to compute the gradient of the loss with respect to each parameter (weight and bias). This is where the network learns, by calculating how much each parameter needs to change to minimize the loss.
    3. optimizer.step():

      • Purpose: Updates parameters.
      • What Happens: Updates the model parameters based on the calculated gradients. This step uses the gradients computed during loss.backward() to adjust the weights in an attempt to minimize the loss.
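
    To see the accumulation behavior concretely, here is a minimal sketch with a hypothetical one-parameter model, showing that calling loss.backward() twice without zeroing doubles the stored gradient:

    import torch

    w = torch.tensor(2.0, requires_grad=True)

    loss = w * 3.0          # d(loss)/dw = 3
    loss.backward()
    print(w.grad)           # tensor(3.)

    # Without zeroing, a second backward pass adds to the stored gradient.
    loss = w * 3.0
    loss.backward()
    print(w.grad)           # tensor(6.)

    # Zeroing restores independence between steps.
    w.grad.zero_()
    loss = w * 3.0
    loss.backward()
    print(w.grad)           # tensor(3.)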