def _module_default(name, value):
    """Return *value*, or the module-level global *name* when *value* is None.

    Keeps ``train`` backward compatible with callers that rely on module-level
    ``optimizer`` / ``criterion`` objects instead of passing them explicitly.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        # Same exception type the original raised when the global was missing.
        raise NameError(
            f"train() needs {name}=...: no module-level {name!r} is defined"
        ) from None


def train(model, train_loader, epochs=10, *, optimizer=None, criterion=None):
    """Train ``model`` on ``train_loader`` for ``epochs`` passes.

    Every batch runs: zero gradients, forward pass, loss, backward pass,
    optimizer step. After each epoch the mean batch loss is printed.

    Args:
        model: Callable module to train; ``model.train()`` is called once
            before the first epoch.
        train_loader: Iterable yielding ``(data, target)`` batches.
        epochs: Number of full passes over ``train_loader``.
        optimizer: Object with ``zero_grad()`` and ``step()``. When omitted,
            the module-level ``optimizer`` is used (original behaviour).
        criterion: Callable ``criterion(output, target)`` returning a loss
            tensor with ``.backward()`` and ``.item()``. When omitted, the
            module-level ``criterion`` is used (original behaviour).

    Returns:
        List with one mean batch loss per epoch; ``nan`` for an epoch in
        which the loader yielded no batches.

    Raises:
        NameError: ``optimizer`` or ``criterion`` was not passed and no
            module-level global of that name exists.
    """
    optimizer = _module_default("optimizer", optimizer)
    criterion = _module_default("criterion", criterion)

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            # .item() yields a plain Python float, so the running sum does not
            # hold on to each batch's autograd graph.
            total_loss += loss.item()
            num_batches += 1
        # Count batches rather than using len(train_loader): an empty loader no
        # longer raises ZeroDivisionError, and loaders without __len__ work.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')
        epoch_losses.append(avg_loss)
    return epoch_losses
import torch
from torchvision import datasets, transforms
# Preprocessing: ToTensor turns each image into a float tensor in [0, 1];
# Normalize((0.5,), (0.5,)) then applies (x - 0.5) / 0.5 to the single
# channel, mapping values into [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split stored under the 'data' root; download=True fetches it
# if it is not already there.
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
# Batches of 64 samples, reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)

# NOTE(review): `model` (and the module-level `optimizer` / `criterion` that
# train() relies on by default) are not defined in this part of the file —
# presumably defined elsewhere; confirm before running this script.
train(model, train_loader)