
PyTorch

Useful Snippets

torch.no_grad() and model.eval()

In PyTorch, torch.no_grad() and model.eval() are both used when evaluating a model, but they serve different purposes:

torch.no_grad():

  • Function: This is a context manager that disables gradient calculation.
  • Purpose: It is used to reduce memory consumption and speed up computations during inference, as gradients are not needed for evaluation.
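  • How it works: Inside a with torch.no_grad() block, autograd does not record operations, so no computation graph is built and the intermediate results that backpropagation would need are not kept. Tensors computed inside the block have requires_grad=False, which saves memory and computation time.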
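A minimal sketch of this behavior on a toy tensor `w` (an illustrative name, not from any library): results computed outside the block carry a `grad_fn`, results computed inside do not, tracking resumes after the block, and the decorator form applies the same rule to a whole function (`predict` is a hypothetical helper).

```python
import torch

w = torch.randn(3, requires_grad=True)

tracked = (w * 2).sum()        # recorded by autograd: has a grad_fn
with torch.no_grad():
    untracked = (w * 2).sum()  # nothing is recorded inside the block
after = (w * 2).sum()          # tracking resumes once the block exits


@torch.no_grad()               # the same behavior applied to a whole function
def predict(v: torch.Tensor) -> torch.Tensor:
    return v * 2


tracked.backward()             # works: a graph exists for `tracked`
print(w.grad)                  # d(sum(2w))/dw = tensor([2., 2., 2.])
```

Calling `untracked.backward()` instead would raise an error, because no graph was recorded for it.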

model.eval():

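  • Function: This method sets the module and all of its submodules to evaluation mode (it sets their training attribute to False; equivalent to model.train(False)).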
  • Purpose: It changes the behavior of certain layers that behave differently during training and evaluation, such as dropout and batch normalization.
  • How it works:
    • Dropout: During training, dropout randomly zeroes elements with probability p and scales the remaining ones by 1/(1-p) to reduce overfitting. In evaluation mode, dropout is the identity, so every unit contributes.
    • Batch Normalization: During training, batch normalization normalizes activations with the current batch's mean and variance and updates its running estimates of both. In evaluation mode, it normalizes with those running estimates and leaves them unchanged.
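The two layer behaviors above can be checked directly on standalone `nn.Dropout` and `nn.BatchNorm1d` modules. This is a small sketch on random data (the input shape and seed are arbitrary choices): dropout is random in training mode and the identity in eval mode, and batch norm updates its running mean only in training mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4) * 3 + 5  # batch of 8 samples, 4 features, mean around 5

# Dropout: random zeroing plus 1/(1-p) rescaling only in training mode
drop = nn.Dropout(p=0.5)
drop.train()
y_train = drop(x)  # about half the entries are 0, the rest are multiplied by 2
drop.eval()
y_eval = drop(x)   # identity: x comes back unchanged

# BatchNorm: batch statistics (and running-stat updates) in training mode,
# stored running statistics (no updates) in eval mode
bn = nn.BatchNorm1d(4)
bn.train()
out_train = bn(x)                           # normalized with this batch's mean/var
mean_after_train = bn.running_mean.clone()  # moved toward the batch mean
bn.eval()
out_eval = bn(x)                            # normalized with running_mean/running_var
mean_after_eval = bn.running_mean.clone()   # not changed by the eval forward pass
```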

when to use each

torch.no_grad():

Use this whenever you are performing inference and do not need to calculate gradients.

model.eval():

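Use this whenever the model contains layers that behave differently during training and evaluation, such as dropout or batch normalization. It does not disable gradient tracking, so it is not a substitute for torch.no_grad(). Call model.train() to switch back before resuming training.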
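A quick way to see that the two are independent, sketched on a hypothetical toy model (`nn.Linear` followed by `nn.Dropout`): `model.eval()` alone still builds an autograd graph, and `torch.no_grad()` alone leaves dropout active, so repeated passes over the same input differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 64), nn.Dropout(p=0.5))
x = torch.randn(8, 64)

model.eval()              # switches dropout off ...
out_eval = model(x)       # ... but autograd still records this forward pass

model.train()
with torch.no_grad():     # stops autograd recording ...
    out_a = model(x)      # ... but dropout is still active, so two passes
    out_b = model(x)      # over the same input give different results
```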

best practice

Combine both: It is common to use both model.eval() and torch.no_grad() together during evaluation:

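```python
import torch

model.eval()  # evaluation mode: changes dropout, batch norm, etc.
with torch.no_grad():  # no gradient tracking inside this block
    # Perform evaluation
    ...

# OR use `torch.no_grad()` as a decorator
@torch.no_grad()
def some_method(self):
    ...
```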
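Putting both together, an evaluation helper might look like the sketch below. The names `evaluate`, `batches`, and `loss_fn` are hypothetical, and it assumes `loss_fn` returns a per-batch mean (as `nn.MSELoss()` does by default). It also restores the caller's training mode afterwards, so it is safe to call in the middle of a training loop.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # no autograd graph is built anywhere inside this function
def evaluate(model: nn.Module, batches, loss_fn) -> float:
    """Return the mean per-sample loss over an iterable of (inputs, targets)."""
    was_training = model.training  # remember the caller's mode
    model.eval()                   # dropout off, batch norm uses running stats
    try:
        total, count = 0.0, 0
        for inputs, targets in batches:
            outputs = model(inputs)
            # loss_fn is assumed to return a batch mean, so re-weight by batch size
            total += loss_fn(outputs, targets).item() * inputs.size(0)
            count += inputs.size(0)
        return total / max(count, 1)
    finally:
        model.train(was_training)  # restore the previous mode, even on error


# Usage with toy data (hypothetical shapes)
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 8), nn.Dropout(0.2), nn.Linear(8, 1))
batches = [(torch.randn(4, 3), torch.randn(4, 1)) for _ in range(3)]
print(evaluate(model, batches, nn.MSELoss()))
```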

register_buffer

Buffers registered with register_buffer are not returned by model.parameters(), so the optimizer never updates them. They are, however, saved in model.state_dict() (unless registered with persistent=False), and, like parameters, they are moved to the target device when .cuda() or .to() is called on the parent module.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.my_tensor = torch.randn(1)                    # plain attribute
        self.register_buffer('my_buffer', torch.randn(1))  # buffer
        self.my_param = nn.Parameter(torch.randn(1))       # parameter

model = MyModel()
print(model.my_tensor)
# tensor([-1.4624])
print(model.state_dict())
# OrderedDict([('my_param', tensor([-1.7173])), ('my_buffer', tensor([0.7523]))])

model.cuda()  # requires a CUDA-capable GPU
print(model.my_tensor)
# tensor([-1.4624])
print(model.state_dict())
# OrderedDict([('my_param', tensor([-1.7173], device='cuda:0')), ('my_buffer', tensor([0.7523], device='cuda:0'))])
```

As shown above in the console output, model.my_tensor is still on the CPU, where it was created, while all parameters and buffers were pushed to the GPU after calling model.cuda().
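The same distinctions can be checked without a GPU. In this sketch, `Normalizer` is a hypothetical module, and `.to(torch.float64)` stands in for `.cuda()`: module conversions apply to parameters and buffers alike, but not to plain tensor attributes. It also shows that a buffer registered with `persistent=False` is still a buffer but is left out of the state dict.

```python
import torch
import torch.nn as nn


class Normalizer(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.register_buffer("mean", torch.zeros(3))                    # saved in state_dict
        self.register_buffer("cache", torch.ones(3), persistent=False)  # not saved
        self.scale = nn.Parameter(torch.ones(3))                        # trainable
        self.plain = torch.zeros(3)                                     # ordinary attribute

    def forward(self, x):
        return (x - self.mean) * self.scale


m = Normalizer()
param_names = [name for name, _ in m.named_parameters()]  # what the optimizer sees
buffer_names = [name for name, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())                  # parameters first, then buffers

m.to(torch.float64)  # converts parameters and buffers; `plain` is left alone
print(param_names, buffer_names, state_keys)
print(m.scale.dtype, m.mean.dtype, m.cache.dtype, m.plain.dtype)
```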

automatic mixed precision

```python
import torch

with torch.amp.autocast(device_type="cuda"):  # replaces the deprecated torch.cuda.amp.autocast()
    outputs = model(inputs)  # eligible operations here run in mixed precision
    loss = compute_loss(outputs)
```

AMP (automatic mixed precision) is a PyTorch feature that speeds up training by running selected operations in a lower-precision type (float16 or bfloat16) while keeping others in float32. This reduces memory usage and improves computational speed, particularly on GPUs, usually without significantly affecting model accuracy. (Source: gpt-4o.)

Key Points:

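  • Precision Types: AMP mixes float16 or bfloat16 (faster computation, less memory) with float32 (for operations such as reductions and loss computations, where reduced precision could cause overflow, underflow, or loss of accuracy).
  • Usage: It is enabled through the torch.amp.autocast context manager (also available as torch.autocast), which chooses the precision of eligible operations inside the block. When training in float16, it is usually paired with a gradient scaler (GradScaler) that scales the loss so small gradients do not underflow; bfloat16 does not need one.
  • Benefits: AMP can shorten training time and reduce memory usage on hardware with fast low-precision arithmetic, such as recent NVIDIA GPUs.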
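To see how autocast scopes precision, here is a small CPU-only check (CPU autocast with bfloat16 is an assumption made so the sketch runs without a GPU). A matrix multiply is on autocast's reduced-precision list, an elementwise add is not, and the same matmul outside the block is back to float32.

```python
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    mm_inside = a @ b    # autocast-eligible op -> bfloat16
    add_inside = a + b   # not on the autocast lists -> keeps float32
mm_outside = a @ b       # outside the block -> float32 again

print(mm_inside.dtype, add_inside.dtype, mm_outside.dtype)
```

Note that `a` and `b` themselves are never modified; autocast casts copies of the inputs for the eligible operations.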

Example:

```python
from contextlib import nullcontext

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 'float32', 'bfloat16', or 'float16'; float16 training also needs a GradScaler
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# autocast on GPU; a do-nothing context manager on CPU
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
```

Functions

einsum

Examples from https://pytorch.org/docs/stable/generated/torch.einsum.html (the printed values depend on the random inputs):

```python
>>> # trace
>>> torch.einsum('ii', torch.randn(4, 4))
tensor(-1.2104)

>>> # diagonal
>>> torch.einsum('ii->i', torch.randn(4, 4))
tensor([-0.1034,  0.7952, -0.2433,  0.4545])

>>> # outer product
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('i,j->ij', x, y)
tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
        [-0.3744,  0.9381,  1.2685, -1.6070],
        [ 0.7208, -1.8058, -2.4419,  3.0936],
        [ 0.1713, -0.4291, -0.5802,  0.7350],
        [ 0.5704, -1.4290, -1.9323,  2.4480]])

>>> # batch matrix multiplication
>>> As = torch.randn(3, 2, 5)
>>> Bs = torch.randn(3, 5, 4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
         [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
         [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
         [ 0.3728, -2.1131,  0.0921,  0.8305]]])

>>> # with sublist format and ellipsis
>>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
         [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
         [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
         [ 0.3728, -2.1131,  0.0921,  0.8305]]])

>>> # batch permute
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])

>>> # equivalent to torch.nn.functional.bilinear
>>> A = torch.randn(3, 5, 4)
>>> l = torch.randn(2, 5)
>>> r = torch.randn(2, 4)
>>> torch.einsum('bn,anm,bm->ba', l, A, r)
tensor([[-0.3430, -5.2405,  0.4494],
        [ 0.3311,  5.5201, -3.0356]])
```
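Each einsum pattern above corresponds to a dedicated operator, which is a handy way to check a subscript string. This sketch compares several patterns against their standard equivalents; the attention-style `bhqd,bhkd->bhqk` pattern and its shapes are an illustrative addition, not one of the documentation examples.

```python
import torch

torch.manual_seed(0)
M = torch.randn(4, 4)
x, y = torch.randn(5), torch.randn(4)
As, Bs = torch.randn(3, 2, 5), torch.randn(3, 5, 4)
q = torch.randn(2, 3, 6, 8)  # (batch, heads, queries, dim): hypothetical shapes
k = torch.randn(2, 3, 7, 8)  # (batch, heads, keys, dim)


def close(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Loose float32 comparison: summation order can differ between kernels."""
    return torch.allclose(a, b, atol=1e-6)


checks = {
    "trace": close(torch.einsum("ii", M), torch.trace(M)),
    "diagonal": close(torch.einsum("ii->i", M), torch.diagonal(M)),
    "outer": close(torch.einsum("i,j->ij", x, y), torch.outer(x, y)),
    "bmm": close(torch.einsum("bij,bjk->bik", As, Bs), torch.bmm(As, Bs)),
    # attention-style scores: contract the shared last axis d
    "scores": close(torch.einsum("bhqd,bhkd->bhqk", q, k), q @ k.transpose(-2, -1)),
}
print(checks)
```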

Misc.

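The weight matrix of torch.nn.Linear(in_features, out_features) is stored transposed relative to its constructor arguments: its shape is (out_features, in_features), and the forward pass computes y = x W^T + b. torch.nn.Embedding(num_embeddings, embedding_dim), by contrast, stores its weight as (num_embeddings, embedding_dim).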

```python
import torch

a = torch.nn.Linear(4, 5)
b = torch.nn.Embedding(4, 5)
print(a)
print(b)
print(a.weight.shape)
print(b.weight.shape)
```

which prints:

```
Linear(in_features=4, out_features=5, bias=True)
Embedding(4, 5)
torch.Size([5, 4])
torch.Size([4, 5])
```
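Because of this layout, reproducing nn.Linear by hand requires a transpose, while an embedding lookup is plain row indexing into its weight. A small sketch checking both (the batch size and indices are arbitrary choices):

```python
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(4, 5)        # weight: (out_features, in_features) = (5, 4)
embedding = torch.nn.Embedding(4, 5)  # weight: (num_embeddings, embedding_dim) = (4, 5)

x = torch.randn(2, 4)                         # batch of 2 inputs with 4 features
manual = x @ linear.weight.T + linear.bias    # y = x W^T + b
idx = torch.tensor([0, 3, 3])
rows = embedding.weight[idx]                  # an embedding lookup selects rows

print(linear(x).shape, embedding(idx).shape)  # (2, 5) and (3, 5)
```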