Dive deep (pun intended) into neural nets with some exciting learnings! Try to become a backprop ninja with Karpathy's courses.

Derive Backprop Gradients step by step

dprobs = dlogprobs / probs
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
dcounts = dprobs * counts_sum_inv
dcounts_sum = -dcounts_sum_inv * counts_sum ** -2
dcounts += dcounts_sum.broadcast_to(counts.shape)
dnorm_logits = dcounts * norm_logits.exp() # norm_logits.exp() is actually counts
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
dlogits = dnorm_logits.clone()
tmp = torch.zeros_like(logits)
tmp[range(n), logits.max(1, keepdim=True).indices.view(-1)] = 1 # try F.one_hot
dlogits += dlogit_maxes * tmp
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0, keepdim=False)
dhpreact = (1.0 - h**2) * dh # tanh'(x) = 1 - tanh(x)**2; may match autograd only up to float rounding
dbngain = (dhpreact * bnraw).sum(0, keepdim=True)
dbnbias = dhpreact.sum(0, keepdim=True)
dbnraw = dhpreact * bngain
dbnvar_inv = (dbnraw * bndiff).sum(0, keepdim=True)
dbndiff = dbnraw * bnvar_inv
# dbnvar = dbnvar_inv * (-0.5) * bnvar_inv ** 3
dbnvar = dbnvar_inv * (-0.5) * (bnvar + 1e-5) ** -1.5
dbndiff2 = 1.0 / (n-1) * torch.ones_like(bndiff2) * dbnvar
dbndiff += 2 * bndiff * dbndiff2
dbnmeani = -dbndiff.sum(0, keepdim=True)
dhprebn = dbndiff.clone()
dhprebn += dbnmeani * 1.0 / n * torch.ones_like(hprebn)
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0, keepdim=False)
demb = dembcat.view(emb.shape)
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k,j]
        dC[ix] += demb[k,j]
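To check the derivation end to end, one can compare each manual gradient with PyTorch autograd. The sketch below assumes a toy setup that is not from this post: random data, hypothetical sizes (vocabulary 27, context 3, embedding 10, hidden 64, batch 32) and arbitrary initialization scales, with the forward pass split into the same intermediate names used above. It starts from the fused dlogits derived later in this post, uses the tanh derivative for dhpreact, and replaces the double loop over Xb with the vectorized `index_add_`. The `cmp` helper is a hypothetical name for this illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
# Hypothetical toy sizes and init scales, chosen only for this check.
vocab_size, block_size, n_embd, n_hidden, n = 27, 3, 10, 64, 32
C = torch.randn(vocab_size, n_embd, requires_grad=True)
W1 = (torch.randn(block_size * n_embd, n_hidden) * 0.2).requires_grad_()
b1 = (torch.randn(n_hidden) * 0.1).requires_grad_()
W2 = (torch.randn(n_hidden, vocab_size) * 0.1).requires_grad_()
b2 = (torch.randn(vocab_size) * 0.1).requires_grad_()
bngain = (1.0 + 0.1 * torch.randn(1, n_hidden)).requires_grad_()
bnbias = (0.1 * torch.randn(1, n_hidden)).requires_grad_()
Xb = torch.randint(0, vocab_size, (n, block_size))
Yb = torch.randint(0, vocab_size, (n,))

# Forward pass, broken into the intermediate names used in the post.
emb = C[Xb]
embcat = emb.view(n, -1)
hprebn = embcat @ W1 + b1
bnmeani = hprebn.sum(0, keepdim=True) / n
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = bndiff2.sum(0, keepdim=True) / (n - 1)  # Bessel's correction
bnvar_inv = (bnvar + 1e-5) ** -0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
h = torch.tanh(hpreact)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Yb)
hpreact.retain_grad()  # keep this non-leaf gradient for comparison
loss.backward()

# Manual backward pass (no graph needed).
with torch.no_grad():
    dlogits = F.softmax(logits, 1)
    dlogits[range(n), Yb] -= 1
    dlogits /= n
    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(0)
    dhpreact = (1.0 - h**2) * dh  # tanh'(x) = 1 - tanh(x)**2
    dbngain = (dhpreact * bnraw).sum(0, keepdim=True)
    dbnbias = dhpreact.sum(0, keepdim=True)
    dbnraw = dhpreact * bngain
    dbnvar_inv = (dbnraw * bndiff).sum(0, keepdim=True)
    dbndiff = dbnraw * bnvar_inv
    dbnvar = dbnvar_inv * (-0.5) * (bnvar + 1e-5) ** -1.5
    dbndiff2 = 1.0 / (n - 1) * torch.ones_like(bndiff2) * dbnvar
    dbndiff += 2 * bndiff * dbndiff2
    dbnmeani = -dbndiff.sum(0, keepdim=True)
    dhprebn = dbndiff + dbnmeani * 1.0 / n * torch.ones_like(hprebn)
    dembcat = dhprebn @ W1.T
    dW1 = embcat.T @ dhprebn
    db1 = dhprebn.sum(0)
    demb = dembcat.view(emb.shape)
    # Vectorized double loop: index_add_ sums rows of demb into dC,
    # accumulating over repeated indices in Xb.
    dC = torch.zeros_like(C).index_add_(0, Xb.view(-1), demb.view(-1, n_embd))

def cmp(name, dt, t):
    # Compare a manual gradient dt with autograd's t.grad.
    exact = torch.equal(dt, t.grad)
    approx = torch.allclose(dt, t.grad, atol=1e-6)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f"{name:10s} | exact: {exact!s:5s} | approximate: {approx!s:5s} | maxdiff: {maxdiff:.2e}")
    return approx

checks = {name: cmp(name, dt, t) for name, dt, t in [
    ("hpreact", dhpreact, hpreact), ("bngain", dbngain, bngain), ("bnbias", dbnbias, bnbias),
    ("W2", dW2, W2), ("b2", db2, b2), ("W1", dW1, W1), ("b1", db1, b1), ("C", dC, C)]}
```

In float32 the `exact` column may show False even when the math is right, because the manual and autograd computations round differently; `approximate` (here with `atol=1e-6`) is the meaningful check.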


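Notation-wise, dX (e.g. dlogprobs) stands for the gradient of loss with respect to X (e.g. logprobs), i.e. $\partial\,\text{loss} / \partial X$. To start, we have the following

$$\text{loss} = -\frac{1}{n}\sum_{i=1}^{n} \text{logprobs}_{i,\,\text{Yb}_i}$$

which easily yields the gradient: $\partial\,\text{loss} / \partial\,\text{logprobs}_{ij} = -1/n$ if $j = \text{Yb}_i$, and $0$ otherwise. In code: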

dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1. / n
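A minimal check of this first step, assuming random stand-in values for logprobs and hypothetical targets Yb rather than real network outputs:

```python
import torch

torch.manual_seed(0)
n, vocab_size = 4, 5                                        # hypothetical toy sizes
logprobs = torch.randn(n, vocab_size, requires_grad=True)   # stand-in log-probabilities
Yb = torch.tensor([0, 2, 1, 4])                             # hypothetical targets

loss = -logprobs[range(n), Yb].mean()
loss.backward()

# Manual gradient: -1/n at each row's target entry, 0 everywhere else.
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n
```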


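Continuing the process, since logprobs = probs.log() and $\frac{d}{dx}\log x = \frac{1}{x}$, we have

$$\frac{\partial\,\text{loss}}{\partial\,\text{probs}_{ij}} = \frac{\partial\,\text{loss}}{\partial\,\text{logprobs}_{ij}} \cdot \frac{1}{\text{probs}_{ij}}$$
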
dprobs = dlogprobs / probs
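The same kind of check for dprobs, again on hypothetical toy data. Any positive matrix works as probs here, since the derivative of log does not depend on each row summing to 1:

```python
import torch

torch.manual_seed(0)
n, vocab_size = 4, 5                                          # hypothetical toy sizes
probs = (torch.rand(n, vocab_size) + 0.1).requires_grad_()    # positive, so log is defined
Yb = torch.tensor([0, 2, 1, 4])                               # hypothetical targets

logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()
loss.backward()

with torch.no_grad():
    dlogprobs = torch.zeros_like(probs)
    dlogprobs[range(n), Yb] = -1.0 / n
    dprobs = dlogprobs / probs   # chain rule: d log(x)/dx = 1/x
```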



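Note that probs = counts * counts_sum_inv involves broadcasting. For simplicity of notation, denote counts_sum_inv by csi. Since csi holds one value per row, the operation is in fact

$$\text{probs}_{ij} = \text{counts}_{ij} \cdot \text{csi}_i$$

that is, the same $\text{csi}_i$ is reused for every column $j$ of row $i$. Beware of the broadcasting by checking the shapes.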

>>> counts.shape, counts_sum_inv.shape
(torch.Size([32, 27]), torch.Size([32, 1]))
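
Because $\text{csi}_i$ contributes to every $\text{probs}_{ij}$ in row $i$, its gradient accumulates over the columns:

$$\frac{\partial\,\text{loss}}{\partial\,\text{csi}_i} = \sum_j \frac{\partial\,\text{loss}}{\partial\,\text{probs}_{ij}} \cdot \text{counts}_{ij}$$

In code this is a sum over dimension 1 with keepdim=True, so the result keeps the (32, 1) shape of counts_sum_inv:
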
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
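Broadcasting is exactly what turns this gradient into a sum. The sketch below isolates the step on toy data: counts_sum_inv is made a leaf of shape (n, 1), and a random tensor (hypothetical) plays the role of the upstream gradient dprobs.

```python
import torch

torch.manual_seed(0)
n, vocab_size = 4, 5                                  # hypothetical toy sizes
counts = torch.rand(n, vocab_size) + 0.1
counts_sum_inv = (1.0 / counts.sum(1, keepdim=True)).requires_grad_()  # leaf, shape (n, 1)
probs = counts * counts_sum_inv                       # (n, 5) * (n, 1) broadcasts along dim 1
dprobs = torch.randn(n, vocab_size)                   # hypothetical upstream gradient
probs.backward(dprobs)                                # autograd reduces over the broadcast dim

# Manual version: multiply by the local derivative, then sum over the broadcast dim.
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
```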


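Note that from the same forward step

probs = counts * counts_sum_inv

one also has $\partial\,\text{probs}_{ij} / \partial\,\text{counts}_{ij} = \text{csi}_i$, which leads to the contribution $\frac{\partial\,\text{loss}}{\partial\,\text{probs}_{ij}} \cdot \text{csi}_i$ (dcounts = dprobs * counts_sum_inv in code).

However, other than contributing to loss through probs, counts also does so through counts_sum and then counts_sum_inv, so the complete gradient dcounts remains to be determined. Denote counts_sum by cs, so that $\text{csi}_i = \text{cs}_i^{-1}$. The forward step

counts_sum_inv = counts_sum**-1

leads to

$$\frac{\partial\,\text{loss}}{\partial\,\text{cs}_i} = -\frac{\partial\,\text{loss}}{\partial\,\text{csi}_i} \cdot \text{cs}_i^{-2}$$

and then, since $\text{cs}_i = \sum_j \text{counts}_{ij}$, each element of row $i$ receives the full gradient of its row sum, which is added to the first contribution:

$$\frac{\partial\,\text{loss}}{\partial\,\text{counts}_{ij}} = \frac{\partial\,\text{loss}}{\partial\,\text{probs}_{ij}} \cdot \text{csi}_i + \frac{\partial\,\text{loss}}{\partial\,\text{cs}_i}$$

Finally, counts = norm_logits.exp(), and exp is its own derivative: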


dnorm_logits = dcounts * norm_logits.exp() # norm_logits.exp() is actually counts
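Both paths into counts matter. This sketch, on hypothetical toy data, backpropagates by hand from the loss to norm_logits and also keeps the result of the first path alone, to show that dropping the route through counts_sum gives a wrong gradient:

```python
import torch

torch.manual_seed(0)
n, vocab_size = 4, 5                                     # hypothetical toy sizes
norm_logits = torch.randn(n, vocab_size, requires_grad=True)
Yb = torch.tensor([0, 2, 1, 4])                          # hypothetical targets

counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()
loss.backward()

with torch.no_grad():
    dlogprobs = torch.zeros_like(logprobs)
    dlogprobs[range(n), Yb] = -1.0 / n
    dprobs = dlogprobs / probs
    dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
    dcounts = dprobs * counts_sum_inv                     # path 1: through probs directly
    dnorm_logits_path1 = dcounts * counts                 # what we would get from path 1 only
    dcounts_sum = -dcounts_sum_inv * counts_sum**-2
    dcounts += dcounts_sum.broadcast_to(counts.shape)     # path 2: through the row sum
    dnorm_logits = dcounts * counts                       # exp'(x) = exp(x) = counts
```

In this setup the first path alone only produces -1/n at each target entry; the second path supplies the softmax term, so the sum matches the fused formula derived in the next section.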


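Next, norm_logits = logits - logit_maxes, where logit_maxes of shape (32, 1) is broadcast across the columns, i.e. $\mathrm{norm\_logits}_{ij} = \mathrm{logits}_{ij} - \mathrm{logit\_maxes}_i$. Since each $\mathrm{logit\_maxes}_i$ is subtracted from every element of row $i$, its gradient is the negated row sum:

$$\frac{\partial\,\text{loss}}{\partial\,\mathrm{logit\_maxes}_i} = -\sum_j \frac{\partial\,\text{loss}}{\partial\,\mathrm{norm\_logits}_{ij}}$$
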
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
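The remaining step into logits combines two routes: logits feeds norm_logits directly, and feeds logit_maxes through the row-wise max, which only passes gradient to the arg-max entry of each row. This sketch on hypothetical toy data takes dnorm_logits from autograd to isolate the step, and builds the one-hot mask with `F.one_hot` instead of indexing a zero tensor:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab_size = 4, 5                                  # hypothetical toy sizes
logits = torch.randn(n, vocab_size, requires_grad=True)
Yb = torch.tensor([0, 2, 1, 4])                       # hypothetical targets

logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes                    # (n, 1) broadcast over columns
logit_maxes.retain_grad()
norm_logits.retain_grad()
loss = F.cross_entropy(norm_logits, Yb)
loss.backward()

with torch.no_grad():
    dnorm_logits = norm_logits.grad                   # taken from autograd to isolate this step
    dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
    # Route 1: identity into logits. Route 2: max sends dlogit_maxes to the arg-max entry.
    one_hot = F.one_hot(logits.max(1).indices, num_classes=vocab_size)
    dlogits = dnorm_logits.clone() + dlogit_maxes * one_hot
```

Because the softmax is unchanged by shifting a row, dlogit_maxes should come out as zero up to floating-point error, so the max route contributes essentially nothing to dlogits.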

Backprop through cross_entropy but All in One Go

Basically, what happens in the forward pass can be described by the following pseudocode:

logprobs = log(softmax(logits, 1))  # softmax: exponentiate each row, then normalize it to sum to 1
loss = -mean(logprobs[range(n), Yb])
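A quick sanity check, on hypothetical toy data, that this forward pass agrees with `F.cross_entropy`, and that subtracting each row's maximum (which the derivation below omits) does not change the loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab_size = 4, 5                          # hypothetical toy sizes
logits = torch.randn(n, vocab_size)
Yb = torch.tensor([0, 2, 1, 4])               # hypothetical targets

# Forward pass written out: log of the row-wise softmax, then mean negative log-likelihood.
logprobs = torch.log(F.softmax(logits, 1))
loss_manual = -logprobs[range(n), Yb].mean()
loss_builtin = F.cross_entropy(logits, Yb)   # fused version of the same computation

# Subtracting each row's max leaves the softmax, and hence the loss, unchanged.
shifted = logits - logits.max(1, keepdim=True).values
loss_shifted = -torch.log(F.softmax(shifted, 1))[range(n), Yb].mean()
```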

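To discuss the derivative for each single element, fix a row $i$ and denote the correct class $\text{Yb}_i$ by $y$. For simplicity of notation, denote $\text{logits}_{ij}$ by $l_j$ and the softmax probability by $p_j = e^{l_j} / \sum_k e^{l_k}$. Only row $i$ of the loss depends on these logits, through the term

$$L_i = -\frac{1}{n}\log p_y = -\frac{1}{n}\log\frac{e^{l_y}}{\sum_k e^{l_k}}$$

Keep in mind that the maximum of the logits (in each row) is supposed to be subtracted inside every exponential, in the numerator and the denominator alike. Here it is omitted because the common factor cancels, so it changes neither $p_j$ nor the gradient of loss.

Now apply the chain rule to derive the derivatives. If $j = y$,

$$\frac{\partial L_i}{\partial l_y} = -\frac{1}{n}\cdot\frac{1}{p_y}\cdot\frac{e^{l_y}\sum_k e^{l_k} - e^{l_y}e^{l_y}}{\left(\sum_k e^{l_k}\right)^2} = -\frac{1}{n}\cdot\frac{1}{p_y}\cdot p_y(1 - p_y)$$

which yields

$$\frac{\partial L_i}{\partial l_y} = \frac{1}{n}(p_y - 1)$$

If $j \neq y$,

$$\frac{\partial L_i}{\partial l_j} = -\frac{1}{n}\cdot\frac{1}{p_y}\cdot\frac{-e^{l_y}e^{l_j}}{\left(\sum_k e^{l_k}\right)^2} = -\frac{1}{n}\cdot\frac{1}{p_y}\cdot(-p_y p_j)$$

which yields

$$\frac{\partial L_i}{\partial l_j} = \frac{1}{n}p_j$$

While the two cases are discussed separately, they share a common part, the softmax probability $p_j / n$; the case $j = y$ only subtracts an extra $1/n$. Hence the whole matrix of gradients can be computed in one go: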

dlogits = F.softmax(logits, 1)
dlogits[range(n), Yb] -= 1
dlogits /= n
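To confirm the fused formula, one can compare it with the gradient that autograd produces for `F.cross_entropy`; the sizes (batch 32, vocabulary 27) and random data below are hypothetical:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab_size = 32, 27                               # hypothetical batch and vocabulary sizes
logits = torch.randn(n, vocab_size, requires_grad=True)
Yb = torch.randint(0, vocab_size, (n,))
loss = F.cross_entropy(logits, Yb)
loss.backward()

with torch.no_grad():
    dlogits = F.softmax(logits, 1)   # common part: p_j for every entry
    dlogits[range(n), Yb] -= 1       # correct class: p_y - 1
    dlogits /= n                     # the loss is a mean over the batch
```

Each row of dlogits sums to zero, since the softmax probabilities in a row sum to 1 and exactly 1 is subtracted per row.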


Finish calculating the remaining derivatives.
