Introduction to PyTorch

PyTorch - Chapter 1

PyTorch Tutorial

We appreciate your patience while we actively develop and enhance our content for a better experience.

Sample Code
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

class SimpleNN(nn.Module):
    """Small fully connected classifier for 28x28 inputs.

    The input is flattened to 784 features, passed through a 128-unit
    hidden layer with a ReLU activation, and projected to 10 output logits.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 128),
            # Without a non-linearity the two Linear layers would collapse
            # into a single linear map.
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Return the 10 raw logits per sample for the batch ``x``.

        ``x`` is flattened from dimension 1 onward (nn.Flatten's default),
        so each sample must contain 28*28 = 784 values.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

def main():
    """Train SimpleNN on MNIST and print its accuracy on the test split.

    Downloads MNIST into ``./data`` when it is not already present, trains
    for five epochs with Adam (lr=0.001) and cross-entropy loss on batches
    of 64, prints the loss of the last batch after every epoch, and finally
    prints the percentage of correctly classified test images.
    """
    print("This program demonstrates how to use PyTorch to create a neural network to classify handwritten digits from the MNIST dataset.")

    # Load the MNIST dataset
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)

    test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # No shuffling for evaluation: sample order does not affect accuracy.
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

    # Create the neural network model
    model = SimpleNN()

    # Set the loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    num_epochs = 5
    for epoch in range(num_epochs):
        for images, labels in train_loader:
            # Zero the gradients; PyTorch accumulates them across backward()
            # calls, so stale values from the previous batch must be cleared.
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = loss_fn(outputs, labels)

            # Backward pass: compute gradients, then update the weights
            loss.backward()
            optimizer.step()

        # Reports the loss of the final batch only, not an epoch average.
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

    # Evaluate the model on the test dataset
    correct = 0
    total = 0
    # Gradients are not needed for inference; disabling them saves memory.
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            # The index of the largest logit along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Accuracy of the neural network on the test set: {accuracy:.2f}%")

if __name__ == "__main__":