Files
ml/ct.py
2024-05-15 00:04:21 +02:00

200 lines
5.9 KiB
Python

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
from torch.profiler import profile, record_function, ProfilerActivity
# Flush denormal floats to zero on CPU (torch.set_flush_denormal); the returned
# support flag is ignored.
torch.set_flush_denormal(True)


def _load_mnist(train):
    """Build one MNIST split (root="data", download=True, ToTensor transform).

    Args:
        train: True for the training split, False for the test split.
    """
    return datasets.MNIST(
        root="data",
        train=train,
        download=True,
        transform=ToTensor(),
    )


training_data = _load_mnist(train=True)
test_data = _load_mnist(train=False)

batch_size = 64
# NOTE(review): num_workers=4 at module scope with no `if __name__ == "__main__":`
# guard -- on spawn-based platforms (Windows, macOS default) worker processes
# re-import this script; confirm it runs there.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=4)

# Peek at the first test batch (if any) to report tensor shapes.
first_batch = next(iter(test_dataloader), None)
if first_batch is not None:
    input_data, labels = first_batch
    print(f"Shape of input_data [N, C, H, W]: {input_data.shape}")
    print(f"Shape of labels: {labels.shape} (type: {labels.dtype})")
# Pick the first available accelerator: CUDA, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# Uncomment to force CPU execution regardless of available accelerators.
# device = "cpu"
print(f"Using {device} device")
class ConvNetwork(nn.Module):
    """Small CNN classifier: 4 conv layers + 2 max-pools, then a 2-layer MLP head.

    Expects single-channel 28x28 input (first conv has 1 input channel). Spatial
    size through ``conv_stack`` (3x3 convs, stride 1, no padding):
    28 -> 26 -> 24 -> pool 12 -> 10 -> 8 -> pool 4, which is why the first
    linear layer takes 128 * 4 * 4 features. ``forward`` returns 10 logits.
    """

    def __init__(self):
        super().__init__()
        # (N, C, H, W) -> (N, C*H*W). Unlike Tensor.view this also accepts
        # non-contiguous feature maps (e.g. channels_last memory format).
        self.flatten = nn.Flatten()
        self.conv_stack = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),     # 28 -> 26
            nn.GELU(),
            nn.Conv2d(32, 64, 3, 1),    # 26 -> 24
            nn.GELU(),
            nn.MaxPool2d(2),            # 24 -> 12
            nn.Conv2d(64, 128, 3, 1),   # 12 -> 10
            nn.GELU(),
            nn.Conv2d(128, 128, 3, 1),  # 10 -> 8
            nn.GELU(),
            nn.MaxPool2d(2),            # 8 -> 4
        )
        self.linear_stack = nn.Sequential(
            nn.Linear(128 * 4 * 4, 128),
            nn.GELU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Return unnormalized class scores of shape (N, 10) for an input batch ``x``."""
        x = self.conv_stack(x)
        x = self.flatten(x)
        logits = self.linear_stack(x)
        return logits
feed_forward_model = ConvNetwork().to(device)
print(feed_forward_model)
# CrossEntropyLoss takes the raw logits from ConvNetwork.forward and integer
# class-index targets. (An MSELoss alternative would need one-hot float targets.)
loss_fn = nn.CrossEntropyLoss()
# Two optimizers over the same parameters. Only `adam` is passed to train() in
# this script; sgd_optimizer is kept as an alternative and is never stepped.
sgd_optimizer = torch.optim.SGD(feed_forward_model.parameters(), lr=1e-3, momentum=0.9)
adam = torch.optim.AdamW(feed_forward_model.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Each batch is moved to the module-level ``device``, passed forward through
    ``model`` and scored with ``loss_fn``. The model is then updated via
    ``optimizer``. Every 100th batch (starting with batch 0) prints the batch
    loss and an approximate count of samples seen so far.
    """
    dataset_size = len(dataloader.dataset)
    model.train()
    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = model(inputs)
        batch_loss = loss_fn(outputs, targets)

        # Backpropagate, apply the update, then clear gradients for the next batch.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 100 != 0:
            continue
        seen = (step + 1) * len(inputs)
        print(f"loss: {batch_loss.item():>7f} [{seen:>5d}/{dataset_size:>5d}]")
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader``, print accuracy and loss, return the misclassified samples.

    Args:
        dataloader: yields ``(inputs, labels)`` batches; ``dataloader.dataset``
            must support ``len()``.
        model: classifier whose output has one score per class along dim 1.
        loss_fn: criterion called as ``loss_fn(pred, y)`` that returns a scalar.

    Returns:
        A list of ``(image, true_label, predicted_label)`` tuples, one per
        misclassified sample, in dataloader order. The tensors stay on ``device``.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    misclassified_images = []
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            predicted = pred.argmax(1)
            hits = predicted == y
            correct += hits.type(torch.float).sum().item()
            # Select misclassified samples with one boolean mask. The previous
            # per-sample `if true != pred` on 0-d tensors forced one host sync
            # per sample on CUDA/MPS. Masking also copies out only the wrong
            # images instead of keeping whole batches alive through views.
            wrong = ~hits
            misclassified_images.extend(zip(X[wrong], y[wrong], predicted[wrong]))
    # NOTE: averages per-batch losses; if loss_fn averages within a batch
    # (CrossEntropyLoss default), a smaller final batch counts as much as a full one.
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return misclassified_images
epochs = 20
start_time = time.time()
for t in range(epochs):
    # Report the learning rate of the optimizer that actually drives training
    # (this previously read sgd_optimizer, which train() never receives).
    current_lr = adam.param_groups[0]["lr"]
    print(f"Epoch {t + 1} (lr: {current_lr})\n-------------------------------")
    train(train_dataloader, feed_forward_model, loss_fn, adam)
    test(test_dataloader, feed_forward_model, loss_fn)
print("Done!")
end_time = time.time()
duration = end_time - start_time
# Wall-clock time for all epochs, including the per-epoch test passes.
print(f"Training took {duration:.2f} seconds")
misclassified_images = test(test_dataloader, feed_forward_model, loss_fn)
print(len(misclassified_images))
# # Visualizing misclassified images
# NOTE: with plt.subplot commented out, every image is drawn into the same axes;
# re-enable it (and cap the loop at 25 images) to get a 5x5 grid.
# plt.figure(figsize=(10, 10))
# for i, (image, true_label, pred_label) in enumerate(misclassified_images):
#     # plt.subplot(5, 5, i + 1)
#     plt.imshow(image.cpu().squeeze(), cmap='gray')
#     plt.title(f'True: {true_label.item()}\nPredicted: {pred_label.item()}')
#     plt.axis('off')
# plt.show()