Dive deep (pun intended) into neural nets and pick up some exciting lessons along the way! Try to become a backprop ninja with Karpathy's courses.

Derive Backprop Gradients step by step

 
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1. / n
dprobs = dlogprobs / probs
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
dcounts = dprobs * counts_sum_inv
dcounts_sum = -dcounts_sum_inv * counts_sum ** -2
dcounts += dcounts_sum.broadcast_to(counts.shape)
dnorm_logits = dcounts * norm_logits.exp() # norm_logits.exp() is actually counts
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
dlogits = dnorm_logits.clone()
tmp = torch.zeros_like(logits)
tmp[range(n), logits.max(1, keepdim=True).indices.view(-1)] = 1 # try F.one_hot
dlogits += dlogit_maxes * tmp
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0, keepdim=False)
# dhpreact = dh * (1 - torch.tanh(hpreact) ** 2)
# dhpreact = (1.0 - h ** 2) * dh # figure out later
dhpreact = hpreact.grad.clone()
dbngain = (dhpreact * bnraw).sum(0, keepdim=True)
dbnbias = dhpreact.sum(0, keepdim=True)
dbnraw = dhpreact * bngain
dbnvar_inv = (dbnraw * bndiff).sum(0, keepdim=True)
dbndiff = dbnraw * bnvar_inv
# dbnvar = dbnvar_inv * (-0.5) * bnvar_inv ** 3
dbnvar = dbnvar_inv * (-0.5) * (bnvar + 1e-5) ** -1.5
dbndiff2 = 1.0 / (n-1) * torch.ones_like(bndiff2) * dbnvar
dbndiff += 2 * bndiff * dbndiff2
dbnmeani = -dbndiff.sum(0, keepdim=True)
dhprebn = dbndiff.clone()
dhprebn += dbnmeani * 1.0 / n * torch.ones_like(hprebn)
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0, keepdim=False)
demb = dembcat.view(emb.shape)
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k,j]
        dC[ix] += demb[k,j]
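To check each hand-derived gradient against PyTorch's autograd, a small comparison utility is handy. The sketch below follows the spirit of the helper used in the exercise notebook (the exact helper there may differ):

import torch

def cmp(s, dt, t):
    # compare a manual gradient dt against the autograd result t.grad
    ex = torch.all(dt == t.grad).item()          # bitwise-exact match
    app = torch.allclose(dt, t.grad)             # match up to floating-point tolerance
    maxdiff = (dt - t.grad).abs().max().item()   # largest absolute deviation
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

Usage is e.g. cmp('logprobs', dlogprobs, logprobs), once logprobs.retain_grad() and loss.backward() have been run on the forward pass.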
 

dlogprobs

Notation-wise, dx stands for the gradient of the loss through x, i.e. $\frac{\partial\,\text{loss}}{\partial x}$. To start we have the following:

$$\text{loss} = -\frac{1}{n}\sum_{i}\text{logprobs}[i, y_i], \qquad y_i = \text{Yb}[i],$$

which easily yields the gradient as follows:

$$\frac{\partial\,\text{loss}}{\partial\,\text{logprobs}[i,j]} =
\begin{cases}
-\frac{1}{n} & \text{if } j = y_i,\\
0 & \text{otherwise.}
\end{cases}$$

dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1. / n
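As a quick sanity check, a self-contained toy example (sizes made up for illustration) confirms that autograd agrees with this scatter of -1/n:

import torch
n, C = 4, 5  # toy batch size and vocabulary size
logprobs = torch.randn(n, C, requires_grad=True)
Yb = torch.randint(0, C, (n,))
loss = -logprobs[range(n), Yb].mean()
loss.backward()
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1. / n
assert torch.allclose(dlogprobs, logprobs.grad)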
 

dprobs

Continuing the process, we have

$$\text{logprobs} = \log(\text{probs})$$

elementwise, so that

$$\frac{\partial\,\text{loss}}{\partial\,\text{probs}[i,j]} = \frac{\partial\,\text{loss}}{\partial\,\text{logprobs}[i,j]}\cdot\frac{1}{\text{probs}[i,j]}.$$

Hence

dprobs = dlogprobs / probs
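The log rule can be cross-checked against autograd on a toy tensor (shapes made up; the random tensor stands in for the upstream gradient dlogprobs):

import torch
probs = (torch.rand(4, 5) + 0.1).requires_grad_()  # keep values away from zero
logprobs = probs.log()
dlogprobs = torch.randn_like(logprobs)  # stand-in for the upstream gradient
logprobs.backward(dlogprobs)
assert torch.allclose(dlogprobs / probs, probs.grad)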
 

dcounts_sum_inv

 
Note that probs = counts * counts_sum_inv is in fact, elementwise,

$$\text{probs}[i,j] = \text{counts}[i,j]\cdot\text{csi}[i],$$

where, for simplicity of notation, counts_sum_inv is denoted by csi. Beware of the broadcasting by checking the shapes:

>> counts.shape, counts_sum_inv.shape
(torch.Size([32, 27]), torch.Size([32, 1]))

Hence

$$\frac{\partial\,\text{loss}}{\partial\,\text{csi}[i]} = \sum_{j}\frac{\partial\,\text{loss}}{\partial\,\text{probs}[i,j]}\cdot\text{counts}[i,j],$$

and

dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
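Because the [32, 1] column is broadcast across all 27 columns, its gradient has to sum over dimension 1. A toy autograd check (treating counts_sum_inv as an independent tensor here, to isolate the broadcasting rule):

import torch
counts = torch.rand(32, 27)
counts_sum_inv = torch.rand(32, 1).requires_grad_()
probs = counts * counts_sum_inv          # broadcasts [32, 1] across 27 columns
dprobs = torch.randn_like(probs)         # stand-in for the upstream gradient
probs.backward(dprobs)
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
assert torch.allclose(dcounts_sum_inv, counts_sum_inv.grad)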
 

dcounts

Note that from

probs = counts * counts_sum_inv

one also has, elementwise,

$$\frac{\partial\,\text{probs}[i,j]}{\partial\,\text{counts}[i,j]} = \text{csi}[i],$$

which leads to

dcounts = dprobs * counts_sum_inv

However, other than contributing to the loss through probs, counts also does so through counts_sum and then counts_sum_inv. The complete gradient dcounts therefore remains to be determined below.
 

dcounts_sum

The forward line

counts_sum_inv = counts_sum**-1

leads to

$$\frac{\partial\,\text{csi}[i]}{\partial\,\text{counts\_sum}[i]} = -\,\text{counts\_sum}[i]^{-2},$$

hence

dcounts_sum = -dcounts_sum_inv * counts_sum ** -2

and then, since counts_sum = counts.sum(1, keepdim=True), the row gradient flows back equally into every entry of the row, completing dcounts:

dcounts += dcounts_sum.broadcast_to(counts.shape)
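Both branches can be checked in one go against autograd with a self-contained toy graph (the random dprobs stands in for the true upstream gradient):

import torch
counts = torch.rand(32, 27).requires_grad_()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
dprobs = torch.randn_like(probs)                         # stand-in upstream gradient
probs.backward(dprobs)
dcounts = dprobs * counts_sum_inv                        # branch 1: directly through probs
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
dcounts_sum = -dcounts_sum_inv * counts_sum**-2          # branch 2: via counts_sum
dcounts = dcounts + dcounts_sum.broadcast_to(counts.shape)
assert torch.allclose(dcounts, counts.grad)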

dnorm_logits

Since counts = norm_logits.exp() and the derivative of the exponential is the exponential itself,

dnorm_logits = dcounts * norm_logits.exp() # norm_logits.exp() is actually counts
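One-line rule, one-line autograd check (toy shapes):

import torch
norm_logits = torch.randn(4, 5).requires_grad_()
counts = norm_logits.exp()
dcounts = torch.randn_like(counts)  # stand-in for the upstream gradient
counts.backward(dcounts)
assert torch.allclose(dcounts * counts, norm_logits.grad)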
 

dlogit_maxes

Since norm_logits = logits - logit_maxes, where logit_maxes has shape [32, 1] and is broadcast across each row, the gradient is the negative row-sum of dnorm_logits:

dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
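Strictly speaking, logit_maxes should not influence the loss at all: softmax is invariant to shifting each row by a constant, so dlogit_maxes should come out numerically around zero. A tiny check of that invariance (toy shapes):

import torch
import torch.nn.functional as F
logits = torch.randn(8, 10)
logit_maxes = logits.max(1, keepdim=True).values
# subtracting a per-row constant leaves the softmax unchanged
assert torch.allclose(F.softmax(logits, 1), F.softmax(logits - logit_maxes, 1))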
 

dlogits

The gradient reaches logits along two paths: directly through norm_logits, and through logit_maxes at the position of each row's maximum:

dlogits = dnorm_logits.clone()
tmp = torch.zeros_like(logits)
tmp[range(n), logits.max(1, keepdim=True).indices.view(-1)] = 1 # try F.one_hot
dlogits += dlogit_maxes * tmp
 

Backprop through cross_entropy but All in One Go

Basically, what happens in the forward pass can be described by the following pseudocode:

logprobs = log(softmax(logits, 1))
loss = -mean(logprobs[range(n), Yb])

To discuss the derivative of each single element, fix a row i, denote logits[i, j] by $l_j$, and write the softmax output as $p_j = e^{l_j}/\sum_k e^{l_k}$. For simplicity of notation, denote the label Yb[i] by $y$.

Keep in mind that the maximum of the logits (in each row) is supposed to be subtracted inside the exponentials. It is omitted here because it does not affect the gradient towards the loss.

Now apply the chain rule to derive the derivatives. The contribution of row i to the loss is $-\frac{1}{n}\log p_y$. If $j = y$,

$$\frac{\partial p_y}{\partial l_j} = p_y(1 - p_y),$$

which yields

$$\frac{\partial}{\partial l_j}\Big(-\tfrac{1}{n}\log p_y\Big) = -\frac{1}{n}\cdot\frac{p_y(1-p_y)}{p_y} = \frac{1}{n}(p_y - 1).$$

If $j \neq y$,

$$\frac{\partial p_y}{\partial l_j} = -p_y\,p_j,$$

which yields

$$\frac{\partial}{\partial l_j}\Big(-\tfrac{1}{n}\log p_y\Big) = \frac{1}{n}\,p_j.$$

While the two cases are discussed separately, they share the common softmax term $\frac{1}{n}p_j$; the label position simply gets an extra $-\frac{1}{n}$:

 
dlogits = F.softmax(logits, 1)
dlogits[range(n), Yb] -= 1
dlogits /= n
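This compact form can be verified directly against F.cross_entropy and autograd on a self-contained toy example:

import torch
import torch.nn.functional as F
n, C = 32, 27
logits = torch.randn(n, C, requires_grad=True)
Yb = torch.randint(0, C, (n,))
loss = F.cross_entropy(logits, Yb)
loss.backward()
dlogits = F.softmax(logits.detach(), 1)  # detach: build the manual gradient outside the graph
dlogits[range(n), Yb] -= 1
dlogits /= n
assert torch.allclose(dlogits, logits.grad, atol=1e-6)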
 

WIP

Finish calculating the remaining derivatives.
