"""Train a small convolutional network on MNIST and report test accuracy.

Running this file downloads MNIST into ``./data`` (if needed), trains
``ConvNetwork`` with AdamW for ``EPOCHS`` epochs while evaluating on the test
split after every epoch, and prints how many test images end up misclassified.
Importing it only defines the model and helpers; nothing is downloaded or
trained until ``main()`` runs.
"""
import time

import torch
import torch.nn.functional as F  # noqa: F401
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function  # noqa: F401
from torch.utils.data import DataLoader

BATCH_SIZE = 64
NUM_WORKERS = 4
EPOCHS = 20
LEARNING_RATE = 1e-3


def get_device():
    """Return the preferred device name: "cuda", then "mps", then "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def load_mnist(root="data"):
    """Return the ``(train, test)`` MNIST datasets, downloading into *root* if needed.

    Samples are converted to tensors with ``ToTensor``.
    """
    # Imported lazily so the model and train/eval helpers can be imported
    # (e.g. for reuse or testing) without torchvision installed.
    from torchvision import datasets
    from torchvision.transforms import ToTensor

    training_data = datasets.MNIST(root=root, train=True, download=True, transform=ToTensor())
    test_data = datasets.MNIST(root=root, train=False, download=True, transform=ToTensor())
    return training_data, test_data


class ConvNetwork(nn.Module):
    """Four 3x3 conv layers with GELU and two 2x2 max-pools, then a 2-layer MLP.

    The first linear layer's input size (128 * 4 * 4) fixes the expected input
    to single-channel 28x28 images, shape (N, 1, 28, 28); the output is
    unnormalized logits of shape (N, 10).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Spatial size: 28 -> 26 -> 24 -> pool 12 -> 10 -> 8 -> pool 4.
        self.conv_stack = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.GELU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1),
            nn.GELU(),
            nn.Conv2d(128, 128, 3, 1),
            nn.GELU(),
            nn.MaxPool2d(2),
        )
        self.linear_stack = nn.Sequential(
            nn.Linear(128 * 4 * 4, 128),
            nn.GELU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) logits."""
        x = self.conv_stack(x)
        x = self.flatten(x)  # (N, 128, 4, 4) -> (N, 2048)
        return self.linear_stack(x)


def train(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch over *dataloader*, updating *model* in place.

    Args:
        dataloader: yields ``(inputs, labels)`` batches.
        model: network to train; it is switched to train mode.
        loss_fn: called as ``loss_fn(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved; defaults to the device of the
            model's first parameter.
    """
    if device is None:
        device = next(model.parameters()).device
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            current = (batch + 1) * len(X)
            print(f"loss: {loss.item():>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn, device=None):
    """Evaluate *model* on *dataloader* and print accuracy and mean batch loss.

    Args:
        dataloader: yields ``(inputs, labels)`` batches.
        model: network to evaluate; it is switched to eval mode.
        loss_fn: called as ``loss_fn(logits, labels)``.
        device: where each batch is moved; defaults to the device of the
            model's first parameter.

    Returns:
        A list of ``(image, true_label, predicted_label)`` tuples, one per
        misclassified sample, as CPU tensors (labels are 0-d tensors).
    """
    if device is None:
        device = next(model.parameters()).device
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0.0, 0.0
    misclassified_images = []
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            predicted = pred.argmax(1)
            correct += (predicted == y).type(torch.float).sum().item()
            # Select every misclassified sample with one boolean mask instead of
            # comparing labels one at a time (each per-element `if` on a device
            # tensor forces a sync), and move them to the CPU so the collected
            # samples do not hold accelerator memory.
            wrong = predicted != y
            misclassified_images.extend(
                zip(X[wrong].cpu(), y[wrong].cpu(), predicted[wrong].cpu())
            )
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return misclassified_images


# The name `test` matches pytest's discovery prefix; opt out so pytest never
# tries to run this helper if the module is imported into a test file.
test.__test__ = False


def plot_misclassified(misclassified_images, max_images=25, cols=5):
    """Show up to *max_images* ``(image, true, predicted)`` samples in a grid."""
    import matplotlib.pyplot as plt

    shown = misclassified_images[:max_images]
    if not shown:
        return
    rows = (len(shown) + cols - 1) // cols
    plt.figure(figsize=(10, 2 * rows))
    for i, (image, true_label, pred_label) in enumerate(shown):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(image.cpu().squeeze(), cmap="gray")
        plt.title(f"True: {true_label.item()}\nPredicted: {pred_label.item()}")
        plt.axis("off")
    plt.tight_layout()
    plt.show()


def main(epochs=EPOCHS, show_misclassified=False):
    """Load MNIST, train ``ConvNetwork`` with AdamW, and report results."""
    torch.set_flush_denormal(True)

    training_data, test_data = load_mnist()
    train_dataloader = DataLoader(
        training_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
    )
    test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    input_data, labels = next(iter(test_dataloader))
    print(f"Shape of input_data [N, C, H, W]: {input_data.shape}")
    print(f"Shape of labels: {labels.shape} (type: {labels.dtype})")

    device = get_device()
    print(f"Using {device} device")

    model = ConvNetwork().to(device)
    print(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    misclassified_images = []
    start_time = time.perf_counter()
    for t in range(epochs):
        # Report the learning rate of the optimizer that is actually training.
        lr = optimizer.param_groups[0]["lr"]
        print(f"Epoch {t + 1} (lr: {lr})\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer, device)
        # The model is unchanged after the last epoch's evaluation, so its
        # misclassified list is the final result; no extra evaluation pass.
        misclassified_images = test(test_dataloader, model, loss_fn, device)
    print("Done!")
    duration = time.perf_counter() - start_time
    print(f"Training took {duration:.2f} seconds")

    print(len(misclassified_images))
    if show_misclassified:
        plot_misclassified(misclassified_images)


# DataLoader workers (num_workers > 0) re-import this module under the "spawn"
# start method (macOS, Windows); the guard keeps them from re-running training.
if __name__ == "__main__":
    main()