SolarDevs

PyTorch Training Loop


If you're new to PyTorch, you may be wondering what the hell the PyTorch training loop is. TensorFlow users who stick to Keras' model.fit() rarely have to write one!

I won't go into the details of why this is a thing in PyTorch, but it gives you a lot of power: you control every step of the process, which makes it easy to debug and to modify how training behaves. (As I said, I won't go into details. Here is a really awesome blog explaining the pros and cons of each framework.)

The basic PyTorch training loop consists of five steps:

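```python
import torch
from torch import nn

# Toy setup, assumed for illustration: a linear model and random batches
model = nn.Linear(10, 1)
loss_function = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
data = [(torch.randn(32, 10), torch.randn(32, 1)) for _ in range(5)]

num_epochs = 100

for epoch in range(num_epochs):
    for X, Y in data:
        # 1. clear gradients
        model.zero_grad()

        # 2. forward pass
        y_pred = model(X)

        # 3. compute loss
        loss = loss_function(y_pred, Y)

        # 4. compute gradients
        loss.backward()

        # 5. adjust learnable parameters
        optimizer.step()
```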
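
In the loop above, `data` stands for any iterable of `(X, Y)` batches. In practice it is commonly a `torch.utils.data.DataLoader`. Here is a minimal sketch, assuming random tensors as a stand-in for a real dataset (`features` and `labels` are made-up names):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical dataset: 100 samples with 10 features and 1 target each
features = torch.randn(100, 10)
labels = torch.randn(100, 1)

# The DataLoader splits the dataset into shuffled batches of up to 32 samples
data = DataLoader(TensorDataset(features, labels), batch_size=32, shuffle=True)

# Iterating yields (X, Y) pairs, exactly what the training loop expects
batch_sizes = [X.shape[0] for X, Y in data]
print(batch_sizes)  # [32, 32, 32, 4]
```

The last batch is smaller because 100 isn't a multiple of 32; pass `drop_last=True` if you'd rather skip it.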

1. Clear Gradients

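We need to clear any existing gradients because PyTorch accumulates them: every call to backward() adds to each parameter's .grad attribute instead of replacing it. If we didn't clear them, each update would use the sum of the gradients from all previous batches, which is not what we want. model.zero_grad() does this for every parameter of the model (optimizer.zero_grad() is equivalent when the optimizer holds all of the model's parameters).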
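
To see the accumulation happen, here is a small sketch using a single scalar tensor `w` as a stand-in for a model parameter (a toy assumption, not a real model):

```python
import torch

# A single scalar "parameter" that tracks gradients
w = torch.tensor(2.0, requires_grad=True)

# Two backward passes without clearing: d(3w)/dw = 3 is added twice
(3 * w).backward()
(3 * w).backward()
accumulated = w.grad.item()
print(accumulated)  # 6.0

# Reset the gradient in place, then run backward again
w.grad.zero_()
(3 * w).backward()
fresh = w.grad.item()
print(fresh)  # 3.0
```

In recent PyTorch versions, `zero_grad()` sets gradients to `None` by default (`set_to_none=True`) instead of filling them with zeros; either way, the next `backward()` starts from scratch.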

2. Forward Pass

We pass the input batch X forward through the network to get its predictions. Calling model(X) runs the model's forward() method.
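
As an illustration, here is a sketch of a small custom model; `TinyNet` and its layer sizes are hypothetical, chosen only to show that `model(X)` dispatches to `forward()`:

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    """Hypothetical two-layer network for illustration."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.out = nn.Linear(8, 1)

    def forward(self, x):
        # Linear -> ReLU -> Linear
        return self.out(torch.relu(self.hidden(x)))

model = TinyNet()
X = torch.randn(16, 4)  # a batch of 16 samples with 4 features each
y_pred = model(X)       # calls TinyNet.forward under the hood
print(y_pred.shape)     # torch.Size([16, 1])
```

Call `model(X)` rather than `model.forward(X)` directly, since `nn.Module.__call__` also runs any registered hooks.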

3. Compute Loss

We compute the loss (the error) between the output of the forward pass and the true labels Y.
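
For example, with mean squared error (one common choice for regression; classification typically uses `nn.CrossEntropyLoss` instead), the loss is just the average of the squared differences. The tensors below are made-up values:

```python
import torch
from torch import nn

# Made-up predictions and targets for three samples
y_pred = torch.tensor([[2.5], [0.0], [2.0]])
Y = torch.tensor([[3.0], [-0.5], [2.0]])

loss_function = nn.MSELoss()
loss = loss_function(y_pred, Y)

# Same thing by hand: ((-0.5)^2 + 0.5^2 + 0^2) / 3 = 0.5 / 3
manual = ((y_pred - Y) ** 2).mean()
print(loss.item(), manual.item())  # both ≈ 0.1667
```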

4. Compute Gradients

We propagate that error backwards through the network: loss.backward() computes the gradient of the loss with respect to every parameter that requires gradients and stores it in that parameter's .grad attribute. (We're not discussing calculus here.)
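
A tiny example makes this concrete. Assume a one-weight "model" y = w·x + b with a squared-error loss (toy values chosen so the gradients are easy to check by hand):

```python
import torch

# Toy parameters that track gradients
w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
x, t = torch.tensor(3.0), torch.tensor(10.0)  # input and target

y = w * x + b          # forward: 2 * 3 + 1 = 7
loss = (y - t) ** 2    # (7 - 10)^2 = 9
loss.backward()        # fills w.grad and b.grad

# By hand: dL/dw = 2 * (y - t) * x = -18, dL/db = 2 * (y - t) = -6
print(w.grad.item(), b.grad.item())  # -18.0 -6.0
```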

5. Adjust Learnable Parameters

Once we have computed how wrong our predictions are, we have fresh gradients for this batch (that's why we did step 1!). optimizer.step() then moves each parameter one step in the direction that reduces the loss (against the gradient), according to the optimizer's algorithm.
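
For plain stochastic gradient descent (`torch.optim.SGD` with its default of no momentum), one step is `param = param - lr * param.grad`; optimizers such as Adam use more elaborate update rules. A minimal sketch with a single toy parameter `w` and a loss minimised at w = 4 (both assumptions for illustration):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.01)

loss = (3 * w - 12) ** 2   # 36 at w = 2, minimised at w = 4
loss.backward()            # dloss/dw = 2 * (3w - 12) * 3 = -36 at w = 2
optimizer.step()           # w <- 2 - 0.01 * (-36) = 2.36

new_loss = ((3 * w - 12) ** 2).item()
print(w.item(), new_loss)  # w moved towards 4, loss dropped below 36
```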

That's the basic training loop. Of course there are many more things you can do, but I will leave those for another blog post.

I hope this helped you understand PyTorch's training loop.
