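"""Train a small convolutional network on MNIST with PyTorch.

Loads MNIST, picks a device (CUDA / MPS / CPU), trains a four-conv-layer
network with AdamW and cross-entropy loss, reports test accuracy after each
epoch, and collects misclassified test images for visualization.
"""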

import time

import matplotlib.pyplot as plt
import torch
from torch import nn
# Only needed for the optional profiling block inside train().
from torch.profiler import profile, record_function, ProfilerActivity
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Flush denormal floats to zero; denormal arithmetic can be very slow on CPUs.
torch.set_flush_denormal(True)

training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=4)
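
# A hedged aside, not in the original script: MNIST pipelines often normalize
# inputs with the dataset's mean/std on top of ToTensor(). A minimal sketch
# using torchvision's Compose and the commonly quoted MNIST statistics:
#
#   from torchvision.transforms import Compose, Normalize
#   transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
#   training_data = datasets.MNIST(root="data", train=True, download=True, transform=transform)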

# Inspect the shape of one test batch.
for input_data, labels in test_dataloader:
    print(f"Shape of input_data [N, C, H, W]: {input_data.shape}")
    print(f"Shape of labels: {labels.shape} (type: {labels.dtype})")
    break

# Prefer CUDA, then Apple's MPS backend, then fall back to the CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
# device = "cpu"  # uncomment to force CPU execution
print(f"Using {device} device")


class ConvNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        # An earlier, much smaller experiment, kept for reference:
        # self.conv_stack = nn.Sequential(
        #     nn.Conv2d(1, 1, 3, 1, padding=0),
        #     nn.GELU(),
        #     nn.Conv2d(1, 1, 3, 1, padding=0),
        #     nn.GELU(),
        # )
        # self.linear_stack = nn.Sequential(
        #     nn.Linear(24 * 24, 64),
        #     nn.GELU(),
        #     nn.Linear(64, 32),
        #     nn.GELU(),
        #     nn.Linear(32, 10),
        # )

        self.conv_stack = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),     # 1x28x28 -> 32x26x26
            nn.GELU(),
            nn.Conv2d(32, 64, 3, 1),    # -> 64x24x24
            nn.GELU(),
            nn.MaxPool2d(2),            # -> 64x12x12
            nn.Conv2d(64, 128, 3, 1),   # -> 128x10x10
            nn.GELU(),
            nn.Conv2d(128, 128, 3, 1),  # -> 128x8x8
            nn.GELU(),
            nn.MaxPool2d(2),            # -> 128x4x4
        )
        self.linear_stack = nn.Sequential(
            nn.Linear(128 * 4 * 4, 128),
            nn.GELU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.conv_stack(x)
        x = self.flatten(x)  # equivalent to x.view(x.size(0), -1)
        logits = self.linear_stack(x)
        return logits
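

# A hedged sanity check, not in the original: confirm the conv stack really
# yields 128 * 4 * 4 features for a 28x28 MNIST input before the linear stack:
#
#   with torch.no_grad():
#       dummy = torch.zeros(1, 1, 28, 28)
#       print(ConvNetwork().conv_stack(dummy).shape)  # torch.Size([1, 128, 4, 4])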

model = ConvNetwork().to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Alternative optimizer:
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
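
# A hedged aside, not in the original: the per-epoch lr printed below only
# changes if a scheduler is attached. A minimal sketch with StepLR:
#
#   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
#   # ...then call scheduler.step() once at the end of each epoch.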


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Optional profiling of the forward pass (uncomment to use):
        # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        #              record_shapes=True) as prof:
        #     with record_function("model_inference"):
        #         pred = model(X)
        # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        pred = model(X)

        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss_val = loss.item()
            current = (batch + 1) * len(X)
            print(f"loss: {loss_val:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    misclassified_images = []
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            # Collect misclassified images for later visualization.
            misclassified_images.extend(
                (img, true_label, pred_label)
                for img, true_label, pred_label in zip(X, y, pred.argmax(1))
                if true_label != pred_label
            )
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return misclassified_images


epochs = 20
start_time = time.time()
for t in range(epochs):
    lr = optimizer.state_dict()["param_groups"][0]["lr"]
    print(f"Epoch {t + 1} (lr: {lr})\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    # Keep the last epoch's misclassifications instead of re-running test() afterwards.
    misclassified_images = test(test_dataloader, model, loss_fn)
print("Done!")
end_time = time.time()
duration = end_time - start_time
print(f"Training took {duration:.2f} seconds")

print(f"{len(misclassified_images)} test images were misclassified")
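
# A hedged follow-up, not in the original script: persisting the trained
# weights is often the next step. "mnist_cnn.pth" is an assumed filename:
#
#   torch.save(model.state_dict(), "mnist_cnn.pth")
#   reloaded = ConvNetwork().to(device)
#   reloaded.load_state_dict(torch.load("mnist_cnn.pth", map_location=device))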

# Visualize up to 25 misclassified images in a 5x5 grid.
plt.figure(figsize=(10, 10))
for i, (image, true_label, pred_label) in enumerate(misclassified_images[:25]):
    plt.subplot(5, 5, i + 1)
    plt.imshow(image.cpu().squeeze(), cmap="gray")
    plt.title(f"True: {true_label.item()}\nPredicted: {pred_label.item()}")
    plt.axis("off")
plt.tight_layout()
plt.show()