Dive deep (pun intended) into neural nets with some exciting learnings! Try to become a backprop ninja with Karpathy's course.
Derive Backprop Gradients step by step
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n
dprobs = dlogprobs / probs
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
dcounts = dprobs * counts_sum_inv
dcounts_sum = -dcounts_sum_inv * counts_sum ** -2
dcounts += dcounts_sum.broadcast_to(counts.shape)
dnorm_logits = dcounts * norm_logits.exp()  # norm_logits.exp() is actually counts
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
dlogits = dnorm_logits.clone()
tmp = torch.zeros_like(logits)
tmp[range(n), logits.max(1, keepdim=True).indices.view(-1)] = 1  # try F.one_hot
dlogits += dlogit_maxes * tmp
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0, keepdim=False)
# dhpreact = dh * (1 - torch.tanh(hpreact) ** 2)
# dhpreact = (1.0 - h ** 2) * dh  # figure out later
dhpreact = hpreact.grad.clone()
dbngain = (dhpreact * bnraw).sum(0, keepdim=True)
dbnbias = dhpreact.sum(0, keepdim=True)
dbnraw = dhpreact * bngain
dbnvar_inv = (dbnraw * bndiff).sum(0, keepdim=True)
dbndiff = dbnraw * bnvar_inv
# dbnvar = dbnvar_inv * (-0.5) * bnvar_inv ** 3
dbnvar = dbnvar_inv * (-0.5) * (bnvar + 1e-5) ** -1.5
dbndiff2 = 1.0 / (n - 1) * torch.ones_like(bndiff2) * dbnvar
dbndiff += 2 * bndiff * dbndiff2
dbnmeani = -dbndiff.sum(0, keepdim=True)
dhprebn = dbndiff.clone()
dhprebn += dbnmeani * 1.0 / n * torch.ones_like(hprebn)
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0, keepdim=False)
demb = dembcat.view(emb.shape)
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k, j]
        dC[ix] += demb[k, j]
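Before walking through each step, it may help to sanity-check the matrix-multiplication rule used above (`dh = dlogits @ W2.T`). The following pure-Python sketch uses made-up toy matrices, not the notebook's tensors; the names `h`, `W2`, `dlogits` only mirror the walkthrough.

```python
# Toy finite-difference check of dh = dlogits @ W2.T (hypothetical numbers).
h = [[0.5, -1.0], [2.0, 0.3]]                    # (2, 2)
W2 = [[1.0, -0.5, 0.2], [0.7, 0.1, -1.1]]        # (2, 3)
dlogits = [[0.1, 0.0, -0.2], [0.3, -0.1, 0.4]]   # upstream grad, (2, 3)

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

def transpose(A):
    return [list(col) for col in zip(*A)]

def loss(hm):
    # a stand-in loss that is linear in logits with coefficients dlogits
    lg = matmul(hm, W2)
    return sum(dlogits[i][j] * lg[i][j] for i in range(2) for j in range(3))

dh = matmul(dlogits, transpose(W2))  # the rule being checked

eps = 1e-6
max_err = 0.0
for i in range(2):
    for j in range(2):
        bumped = [row[:] for row in h]
        bumped[i][j] += eps
        fd = (loss(bumped) - loss(h)) / eps
        max_err = max(max_err, abs(fd - dh[i][j]))
```

The same probe applied to `W2` instead of `h` would verify `dW2 = h.T @ dlogits`.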
dlogprobs
Notation-wise, `dlogprobs` stands for the gradient of the loss with respect to `logprobs`. To start, we have the following

$$\text{loss} = -\frac{1}{n}\sum_{i=1}^{n} \text{logprobs}_{i,\,(Y_b)_i}$$

which easily yields the gradient as follows.

$$\left(\frac{d\,\text{loss}}{d\,\text{logprobs}}\right)_{i,j} = \begin{cases}-\dfrac{1}{n}, & j = (Y_b)_i \\ 0, & \text{otherwise}\end{cases}$$
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n
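A quick finite-difference probe (with hypothetical toy values, not the notebook's tensors) confirms the pattern: entries picked out by `Yb` get $-1/n$, everything else gets $0$.

```python
# Toy check of dlogprobs: n = 2 rows, 3 classes, hypothetical targets Yb.
logprobs = [[-1.0, -2.0, -0.5],
            [-0.3, -1.2, -0.9]]
Yb = [0, 2]
n = len(logprobs)

def loss(lp):
    # loss = -mean of the target log-probabilities, as in the forward pass
    return -sum(lp[i][Yb[i]] for i in range(n)) / n

eps = 1e-6
grad = [[0.0] * 3 for _ in range(n)]
for i in range(n):
    for j in range(3):
        bumped = [row[:] for row in logprobs]
        bumped[i][j] += eps
        grad[i][j] = (loss(bumped) - loss(logprobs)) / eps

# grad[i][Yb[i]] is -1/n = -0.5 here; all other entries are 0
```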
dprobs
Continuing the process, we have

$$\text{logprobs}_{i,j} = \log(\text{probs}_{i,j})$$

Hence

$$\left(\frac{d\,\text{loss}}{d\,\text{probs}}\right)_{i,j} = \left(\frac{d\,\text{loss}}{d\,\text{logprobs}}\right)_{i,j}\cdot\left(\frac{d\,\text{logprobs}}{d\,\text{probs}}\right)_{i,j} = \left(\frac{d\,\text{loss}}{d\,\text{logprobs}}\right)_{i,j}\cdot\frac{1}{\text{probs}_{i,j}}$$
dprobs = dlogprobs / probs
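The $1/\text{probs}$ factor is just the derivative of $\log$; a one-line numeric check at a hypothetical toy value:

```python
import math

# d(log p)/dp should equal 1/p; probe at a made-up p.
p = 0.37
eps = 1e-7
fd = (math.log(p + eps) - math.log(p - eps)) / (2 * eps)  # central difference
analytic = 1.0 / p
```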
dcounts_sum_inv
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
Note that `probs = counts * counts_sum_inv` is in fact

$$\text{probs}_{i,j} = \text{counts}_{i,j}\cdot \text{csi}_{i}$$

where, for simplicity of notation, `counts_sum_inv` is denoted by `csi`.
Beware of the broadcasting by checking the shapes.

>>> counts.shape, counts_sum_inv.shape
(torch.Size([32, 27]), torch.Size([32, 1]))
Hence

$$\left(\frac{d\,\text{probs}}{d\,\text{csi}}\right)_{i,j} = \text{counts}_{i,j}$$

and, summing the contributions of all columns in row $i$,

$$\left(\frac{d\,\text{loss}}{d\,\text{csi}}\right)_{i} = \sum_{j}\left(\frac{d\,\text{loss}}{d\,\text{probs}}\right)_{i,j}\cdot \text{counts}_{i,j}$$
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
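To see why the row-wise sum shows up: $\text{csi}_i$ scales every column of row $i$, so nudging it collects a contribution from each column. A sketch of one row, with made-up numbers rather than the notebook's tensors:

```python
# One row of counts, a pretend upstream gradient dloss/dprobs for that row,
# and the scalar csi that scales the whole row. All numbers are hypothetical.
counts_row = [2.0, 1.0, 3.0]
dprobs_row = [0.1, -0.2, 0.4]
csi = 0.25

def downstream(csi_val):
    # loss contribution, linear in probs with coefficients dprobs_row
    return sum(d * c * csi_val for d, c in zip(dprobs_row, counts_row))

eps = 1e-6
fd = (downstream(csi + eps) - downstream(csi)) / eps
analytic = sum(d * c for d, c in zip(dprobs_row, counts_row))  # (dprobs*counts).sum(1)
```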
dcounts
Note that from `probs = counts * counts_sum_inv` one also has

$$\left(\frac{d\,\text{probs}}{d\,\text{counts}}\right)_{i,j} = \text{csi}_{i}$$

However, other than contributing to the loss through `probs`, `counts` also does so through `counts_sum` and then `counts_sum_inv`. The complete gradient `dcounts` therefore remains to be determined until that second path is accounted for (hence the `+=` in the code).
counts_sum_inv = counts_sum ** -1
dcounts_sum
Denoting `counts_sum` by `cs`, the relation above leads to

$$\left(\frac{d\,\text{csi}}{d\,\text{cs}}\right)_{i} = -\text{cs}_{i}^{-2}$$

and then

$$\left(\frac{d\,\text{loss}}{d\,\text{cs}}\right)_{i} = -\left(\frac{d\,\text{loss}}{d\,\text{csi}}\right)_{i}\cdot \text{cs}_{i}^{-2}$$
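The $-\text{cs}^{-2}$ factor is the ordinary power rule for $x^{-1}$; a quick numeric probe at a hypothetical toy value:

```python
# d(cs**-1)/d(cs) should equal -cs**-2; probe at a made-up cs.
cs = 4.0
eps = 1e-6
fd = ((cs + eps) ** -1 - (cs - eps) ** -1) / (2 * eps)  # central difference
analytic = -cs ** -2
```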
dnorm_logits
dnorm_logits = dcounts * norm_logits.exp()  # norm_logits.exp() is actually counts
dlogit_maxes
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
dlogits
dlogits = dnorm_logits.clone()
tmp = torch.zeros_like(logits)
tmp[range(n), logits.max(1, keepdim=True).indices.view(-1)] = 1  # try F.one_hot
dlogits += dlogit_maxes * tmp
Backprop through cross_entropy, but all in one go
Basically, what happens in the forward pass can be described by the following pseudocode:

logprobs = log(softmax(logits, 1))
loss = -mean(logprobs[range(n), Yb])
To discuss the derivative of each single element, for every $i$ denote $(Y_b)_i$ by $y$. For simplicity of notation, denote $\text{logits}$ by $lg$.
l oss β = β n 1 β i = 1 β n β l o g p ro b s i , y β = β n 1 β i = 1 β n β log ( p ro b s i , y β ) = β n 1 β i = 1 β n β log ( β k β exp { l g i , k β } exp { l g i , y β } β ) β
Keep in mind that the maximum of the logits (in each row) is supposed to be subtracted in the numerator. Here it is omitted because it does not affect the gradient of the loss.
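That omission is safe because softmax is invariant to shifting a row by a constant; a small numeric demonstration with a made-up logits row:

```python
import math

def naive_softmax(xs):
    # deliberately no max subtraction, to show the shift changes nothing
    exps = [math.exp(x) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

row = [1.0, -2.0, 3.0]              # hypothetical logits row
shifted = [x - max(row) for x in row]
a = naive_softmax(row)
b = naive_softmax(shifted)
# a and b agree elementwise: the shifted and unshifted rows define
# the same probabilities, hence the same loss
```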
Now apply the chain rule to derive the derivatives. If $j \neq y$,

$$\left(\frac{d\,\text{loss}}{d\,lg}\right)_{i,j} = -\frac{1}{n}\cdot\frac{\sum_{k}\exp\{lg_{i,k}\}}{\exp\{lg_{i,y}\}}\cdot(-1)\cdot\frac{\exp\{lg_{i,y}\}}{\left(\sum_{k}\exp\{lg_{i,k}\}\right)^{2}}\cdot\exp\{lg_{i,j}\}$$

which yields

$$\left(\frac{d\,\text{loss}}{d\,lg}\right)_{i,j} = \frac{1}{n}\cdot\frac{\exp\{lg_{i,j}\}}{\sum_{k}\exp\{lg_{i,k}\}} = \frac{1}{n}\cdot\text{softmax}(lg_{i,\cdot})_{j}$$
If $j = y$,

$$\left(\frac{d\,\text{loss}}{d\,lg}\right)_{i,j} = -\frac{1}{n}\cdot\frac{\sum_{k}\exp\{lg_{i,k}\}}{\exp\{lg_{i,y}\}}\cdot\frac{\exp\{lg_{i,y}\}\cdot\sum_{k}\exp\{lg_{i,k}\} - \exp\{lg_{i,y}\}^{2}}{\left(\sum_{k}\exp\{lg_{i,k}\}\right)^{2}}$$

which yields

$$\left(\frac{d\,\text{loss}}{d\,lg}\right)_{i,j} = -\frac{1}{n} + \frac{1}{n}\cdot\frac{\exp\{lg_{i,y}\}}{\sum_{k}\exp\{lg_{i,k}\}} = \frac{1}{n}\cdot\left(\text{softmax}(lg_{i,\cdot})_{y} - 1\right)$$
While the two cases are discussed separately, they share a common softmax factor, which is what makes the fused implementation so compact.
dlogits = F.softmax(logits, 1)
dlogits[range(n), Yb] -= 1
dlogits /= n
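A finite-difference check of this fused formula, in pure Python with toy numbers (not the notebook's tensors):

```python
import math

# Check dlogits = (softmax(logits) - one_hot(Yb)) / n against finite
# differences of the cross-entropy loss. All values are hypothetical.
logits = [[0.5, -1.0, 2.0],
          [1.5, 0.0, -0.5]]
Yb = [2, 0]
n = len(logits)

def softmax_row(xs):
    m = max(xs)                       # subtract the row max, as discussed above
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def loss(lg):
    return -sum(math.log(softmax_row(lg[i])[Yb[i]]) for i in range(n)) / n

# analytic gradient from the fused formula
dlogits = []
for i in range(n):
    row = softmax_row(logits[i])
    row[Yb[i]] -= 1.0
    dlogits.append([v / n for v in row])

# compare against finite differences, entry by entry
eps = 1e-6
max_err = 0.0
for i in range(n):
    for j in range(3):
        bumped = [r[:] for r in logits]
        bumped[i][j] += eps
        fd = (loss(bumped) - loss(logits)) / eps
        max_err = max(max_err, abs(fd - dlogits[i][j]))
```

Note that each row of `dlogits` sums to zero: softmax sums to one, and subtracting the one-hot target removes exactly that one.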
Finish calculating the rest of the derivatives.