dprobs = dlogprobs / probs
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
dcounts = dprobs * counts_sum_inv
dcounts_sum = -dcounts_sum_inv * counts_sum ** -2
dcounts += dcounts_sum.broadcast_to(counts.shape)
dnorm_logits = dcounts * norm_logits.exp()  # norm_logits.exp() is actually counts
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
dlogits = dnorm_logits.clone()
tmp = torch.zeros_like(logits)
tmp[range(n), logits.max(1, keepdim=True).indices.view(-1)] = 1  # try F.one_hot
dlogits += dlogit_maxes * tmp
db2 = dlogits.sum(0, keepdim=False)
# analytic form: dhpreact = (1.0 - h ** 2) * dh, since h = torch.tanh(hpreact)
dhpreact = hpreact.grad.clone()  # placeholder: reuse autograd's result for now
dbngain = (dhpreact * bnraw).sum(0, keepdim=True)
dbnbias = dhpreact.sum(0, keepdim=True)
dbnraw = dhpreact * bngain
dbnvar_inv = (dbnraw * bndiff).sum(0, keepdim=True)
dbndiff = dbnraw * bnvar_inv
dbnvar = dbnvar_inv * (-0.5) * (bnvar + 1e-5) ** -1.5  # equivalently: dbnvar_inv * (-0.5) * bnvar_inv ** 3
dbndiff2 = 1.0 / (n - 1) * torch.ones_like(bndiff2) * dbnvar
dbndiff += 2 * bndiff * dbndiff2
dbnmeani = -dbndiff.sum(0, keepdim=True)
dhprebn = dbndiff.clone()
dhprebn += dbnmeani * 1.0 / n * torch.ones_like(hprebn)
db1 = dhprebn.sum(0, keepdim=False)
demb = dembcat.view(emb.shape)
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k, j]
        dC[ix] += demb[k, j]  # each row of C accumulates gradients from every position that looked it up
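As the `# try F.one_hot` note above suggests, the one-hot mask `tmp` can also be built with `F.one_hot`. A minimal sketch, assuming `torch.nn.functional` is imported as `F` (as it is later in this post):

# one-hot rows marking each row's argmax, replacing the manual zeros_like + indexing
tmp = F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]).float()
dlogits += dlogit_maxes * tmp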
Notation-wise, `dlogprobs` stands for the gradient of the loss with respect to `logprobs`. To start, we have the following:
$$loss = -\frac{1}{n}\sum_{i=1}^{n}\sum_{j\in Y_{b}}logprobs_{i,j}$$
which easily yields the gradient as follows.
$$
\left(\frac{dloss}{dlogprobs}\right)_{i,j}=
\begin{cases}
-\frac{1}{n}, & j\in Y_{b}\\
0, & \text{otherwise}
\end{cases}
$$
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1. / n
Continuing the process, we have
$$logprobs_{i,j}=\log(probs_{i,j})$$
Hence
$$\left(\frac{dloss}{dprobs}\right)_{i,j}=\left(\frac{dloss}{dlogprobs}\right)_{i,j}\cdot\left(\frac{dlogprobs}{dprobs}\right)_{i,j}=\left(\frac{dloss}{dlogprobs}\right)_{i,j}\cdot\frac{1}{probs_{i,j}}$$
dprobs = dlogprobs / probs
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
Note that `probs = counts * counts_sum_inv` is in fact
$$probs_{i,j}=counts_{i,j}\cdot csi_{i}$$
For simplicity of notation, `counts_sum_inv` is denoted by `csi`.
>>> counts.shape, counts_sum_inv.shape
(torch.Size([32, 27]), torch.Size([32, 1]))
Hence, since $csi_{i}$ is broadcast across every column of row $i$,
$$\left(\frac{dprobs}{dcsi}\right)_{i,j}=counts_{i,j}$$
and summing the contributions over $j$ gives
$$\left(\frac{dloss}{dcsi}\right)_{i}=\sum_{j}\left(\frac{dloss}{dprobs}\right)_{i,j}\cdot counts_{i,j}$$
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
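The `sum(1, keepdim=True)` mirrors the broadcast of `counts_sum_inv` in the forward pass. A tiny autograd check of this pattern, with made-up tensors rather than the actual ones:

import torch

# the forward pass broadcasts b across dim 1, so the backward must sum over dim 1
a = torch.randn(32, 27)
b = torch.randn(32, 1, requires_grad=True)
(a * b).sum().backward()
print(torch.allclose(b.grad, a.sum(1, keepdim=True)))  # True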
Note that from `probs = counts * counts_sum_inv` one also has
$$\left(\frac{dprobs}{dcounts}\right)_{i,j}=csi_{i}$$
However, other than contributing to the loss through `probs`, `counts` also contributes through `counts_sum` and then `counts_sum_inv`, so the complete gradient `dcounts` remains to be determined.
`counts_sum_inv = counts_sum ** -1` (abbreviating `counts_sum` as $cs$) leads to
$$\left(\frac{dcsi}{dcs}\right)_{i}=-cs_{i}^{-2}$$
and then
$$\left(\frac{dloss}{dcounts\_sum}\right)_{i}=-\left(\frac{dloss}{dcsi}\right)_{i}\cdot counts\_sum_{i}^{-2}$$
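In code, this branch and the broadcast back through `counts_sum` (assuming `counts_sum = counts.sum(1, keepdim=True)` in the forward pass) complete `dcounts`:

dcounts_sum = -dcounts_sum_inv * counts_sum ** -2
dcounts += dcounts_sum.broadcast_to(counts.shape)  # a sum in the forward pass spreads its gradient back evenly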
dnorm_logits = dcounts * norm_logits.exp()  # norm_logits.exp() is actually counts
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
Basically, what happens in the forward pass can be described by the following pseudocode (subtracting the row max in `norm_logits` only stabilizes the exponentials and does not change `probs`):
logprobs = log(softmax(logits, 1))
loss = -mean(logprobs[range(n), Yb])
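Spelled out with the intermediate tensors used throughout this post (a reconstruction of the forward pass, assuming the usual max-subtraction trick):

# forward pass, step by step
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes   # subtract the row max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum ** -1
probs = counts * counts_sum_inv      # softmax(logits, 1)
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()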
To discuss the derivative of each single element, for every $i$ denote $(Y_{b})_{i}$ by $y$. For simplicity of notation, denote $logits$ by $lg$.
$$
\begin{aligned}
loss&=-\frac{1}{n}\sum_{i=1}^{n}logprobs_{i,y}\\
&=-\frac{1}{n}\sum_{i=1}^{n}\log(probs_{i,y})\\
&=-\frac{1}{n}\sum_{i=1}^{n}\log\left(\frac{\exp\{lg_{i,y}\}}{\sum_{k}\exp\{lg_{i,k}\}}\right)
\end{aligned}
$$
Now apply the chain rule to derive the derivatives. If $j\neq y$,
$$\left(\frac{dloss}{dlg}\right)_{i,j}=-\frac{1}{n}\cdot\frac{\sum_{k}\exp\{lg_{i,k}\}}{\exp\{lg_{i,y}\}}\cdot(-1)\cdot\frac{\exp\{lg_{i,y}\}}{(\sum_{k}\exp\{lg_{i,k}\})^{2}}\cdot\exp\{lg_{i,j}\}$$
which yields
$$\left(\frac{dloss}{dlg}\right)_{i,j}=\frac{1}{n}\cdot\frac{\exp\{lg_{i,j}\}}{\sum_{k}\exp\{lg_{i,k}\}}=\frac{1}{n}\cdot\text{softmax}(lg_{i,\cdot})_{j}$$
If $j=y$,
$$\left(\frac{dloss}{dlg}\right)_{i,j}=-\frac{1}{n}\cdot\frac{\sum_{k}\exp\{lg_{i,k}\}}{\exp\{lg_{i,y}\}}\cdot\frac{\exp\{lg_{i,y}\}\cdot\sum_{k}\exp\{lg_{i,k}\}-\exp\{lg_{i,y}\}^{2}}{(\sum_{k}\exp\{lg_{i,k}\})^{2}}$$
which yields
$$\left(\frac{dloss}{dlg}\right)_{i,j}=-\frac{1}{n}+\frac{1}{n}\cdot\frac{\exp\{lg_{i,y}\}}{\sum_{k}\exp\{lg_{i,k}\}}=\frac{1}{n}\cdot\left(\text{softmax}(lg_{i,\cdot})_{y}-1\right)$$
While the two cases are discussed separately, they share the common softmax term, which leads to a compact implementation (note the final division by $n$, carrying the $\frac{1}{n}$ factor from both cases):
dlogits = F.softmax(logits, 1)
dlogits[range(n), Yb] -= 1
dlogits /= n
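As a sanity check, this compact gradient matches what autograd computes for `F.cross_entropy`. A minimal sketch; the tensor shapes here are assumptions matching the batch above:

import torch
import torch.nn.functional as F

n, vocab_size = 32, 27  # assumed shapes
logits = torch.randn(n, vocab_size, requires_grad=True)
Yb = torch.randint(0, vocab_size, (n,))

loss = F.cross_entropy(logits, Yb)  # mean reduction, like the manual loss
loss.backward()

with torch.no_grad():
    dlogits = F.softmax(logits, 1)
    dlogits[range(n), Yb] -= 1
    dlogits /= n
print(torch.allclose(dlogits, logits.grad))  # True up to float tolerance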
Finish calculating the rest of the derivatives.
A typical training loop looks like this. One thing worth mentioning is the line `optimizer.zero_grad()`.
Gradients in PyTorch are accumulated by default. Without zeroing them, gradients from multiple backpropagation steps would add up, leading to incorrect model updates.
In other words, zeroing gradients ensures that each backward pass calculates gradients based only on the current mini-batch.
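A tiny demonstration of this accumulation behavior (illustrative only):

import torch

w = torch.tensor([1.0], requires_grad=True)
(2 * w).backward()
print(w.grad)  # tensor([2.])
(2 * w).backward()
print(w.grad)  # tensor([4.]) -- the second backward added to the first
w.grad.zero_()  # what optimizer.zero_grad() does for every parameter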
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch_X, batch_y in dataloader:
        # Forward pass
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Record average loss for the epoch
    avg_loss = epoch_loss / len(dataloader)
These steps are part of each training loop iteration and help in adjusting the model’s parameters to improve its predictions.
`optimizer.zero_grad()`:
Purpose: Clears old gradients.
What Happens: Resets the gradients of all model parameters to zero. This prevents the accumulation of gradients from multiple backpropagations, ensuring each mini-batch is computed independently.
`loss.backward()`:
Purpose: Computes gradients.
What Happens: Performs backpropagation to compute the gradient of the loss with respect to each parameter (weight and bias). This is where the network learns, by calculating how much each parameter needs to change to minimize the loss.
`optimizer.step()`:
Purpose: Updates parameters.
What Happens: Updates the model parameters based on the calculated gradients. This step uses the gradients computed during `loss.backward()` to adjust the weights in an attempt to minimize the loss.
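For intuition, here is a hand-rolled equivalent of `optimizer.step()` for plain SGD without momentum. This is a sketch; `model` and `lr` are the assumed names from the loop above:

# what optimizer.step() boils down to for vanilla SGD
with torch.no_grad():
    for p in model.parameters():
        p -= lr * p.grad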