We appreciate your patience while we actively develop and enhance our content for a better experience.
Sample Code
import torch
import torch.nn as nn
import torch.optim as optim


class SimpleNN(nn.Module):
    """Two-layer fully connected classifier for 28x28 inputs and 10 classes.

    The network flattens each sample, applies ``Linear(784, 128)`` followed by
    ReLU, and finishes with ``Linear(128, 10)``. The output layer has no
    activation: the model returns raw logits, which is what
    ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        # No ReLU after the last Linear layer. Clamping logits to be
        # non-negative would stop the model from pushing a class score below
        # zero and would zero the gradient of every clamped output.
        # Layer indices 0 and 2 still hold the parameters, so state_dict keys
        # are the same as with a trailing activation.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of inputs.

        Args:
            x: Batch whose samples each flatten to 28 * 28 = 784 features.

        Returns:
            Tensor with 10 unnormalized class scores per sample.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


def main() -> None:
    """Train SimpleNN on MNIST for five epochs and print test accuracy.

    Downloads MNIST into ``./data`` if needed, trains with Adam and
    cross-entropy loss, prints the loss of the last batch of each epoch, then
    evaluates on the test split and prints the accuracy as a percentage.
    """
    # torchvision is only needed for the dataset pipeline, so it is imported
    # here; this keeps SimpleNN importable in environments without it.
    import torchvision
    import torchvision.transforms as transforms

    print(
        "This program demonstrates how to use PyTorch to create a neural "
        "network to classify handwritten digits from the MNIST dataset."
    )

    # Load the MNIST dataset. Normalize subtracts 0.5 and divides by 0.5 on
    # the single image channel.
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    train_set = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=64, shuffle=True
    )
    test_set = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=64, shuffle=False
    )

    # Create the neural network model.
    model = SimpleNN()

    # Set the loss function and optimizer.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train the model.
    num_epochs = 5
    model.train()
    for epoch in range(num_epochs):
        for images, labels in train_loader:
            # Gradients accumulate across backward() calls, so clear them
            # before every step.
            optimizer.zero_grad()

            # Forward pass.
            outputs = model(images)
            loss = loss_fn(outputs, labels)

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()

        # Reports the loss of the final batch of the epoch, not an average.
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

    # Evaluate the model on the test dataset.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            # Index of the highest logit is the predicted class.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Accuracy of the neural network on the test set: {accuracy:.2f}%")


if __name__ == "__main__":
    main()