got over 99%
This commit is contained in:
199
ct.py
Normal file
199
ct.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import datasets
|
||||
from torchvision.transforms import ToTensor
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
import matplotlib.pyplot as plt
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
torch.set_flush_denormal(True)
|
||||
|
||||
training_data = datasets.MNIST(
|
||||
root="data",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=ToTensor(),
|
||||
)
|
||||
|
||||
test_data = datasets.MNIST(
|
||||
root="data",
|
||||
train=False,
|
||||
download=True,
|
||||
transform=ToTensor(),
|
||||
)
|
||||
|
||||
batch_size = 64
|
||||
|
||||
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=4)
|
||||
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=4)
|
||||
|
||||
for input_data, labels in test_dataloader:
|
||||
print(f"Shape of input_data [N, C, H, W]: {input_data.shape}")
|
||||
print(f"Shape of labels: {labels.shape} (type: {labels.dtype})")
|
||||
break
|
||||
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
# device = "cpu"
|
||||
print(f"Using {device} device")
|
||||
|
||||
|
||||
class ConvNetwork(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.flatten = nn.Flatten()
|
||||
# self.conv_stack = nn.Sequential(
|
||||
# nn.Conv2d(1, 1, 3, 1, padding=0),
|
||||
# nn.GELU(),
|
||||
# # nn.MaxPool2d(2, 1),
|
||||
# # nn.GELU(),
|
||||
# nn.Conv2d(1, 1, 3, 1, padding=0),
|
||||
# nn.GELU(),
|
||||
# # nn.Conv2d(1, 1, 3, 1, padding=0, device=device),
|
||||
# # nn.GELU(),
|
||||
# # nn.Conv2d(1, 1, 3, 1, padding=0, device=device),
|
||||
# # nn.GELU(),
|
||||
# # nn.Conv2d(1, 1, 3, 1, padding=0, device=device),
|
||||
# # nn.GELU(),
|
||||
# # nn.Conv2d(1, 1, 3, 1, padding=0, device=device),
|
||||
# # nn.GELU(),
|
||||
# # nn.MaxPool2d(),
|
||||
# )
|
||||
#
|
||||
# self.linear_stack = nn.Sequential(
|
||||
# nn.Linear(24 * 24, 64),
|
||||
# nn.GELU(),
|
||||
# # nn.Dropout(0.1),
|
||||
# nn.Linear(64, 32),
|
||||
# nn.GELU(),
|
||||
# # nn.Linear(256, 128),
|
||||
# # nn.GELU(),
|
||||
# # nn.Dropout(0.2),
|
||||
# # nn.Linear(128, 64),
|
||||
# # nn.GELU(),
|
||||
# # nn.Dropout(0.2),
|
||||
# nn.Linear(32, 10),
|
||||
# # nn.GELU(),
|
||||
# # nn.Linear(512, 10),
|
||||
# )
|
||||
|
||||
self.conv_stack = nn.Sequential(
|
||||
nn.Conv2d(1, 32, 3, 1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(32, 64, 3, 1),
|
||||
nn.GELU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Conv2d(64, 128, 3, 1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(128, 128, 3, 1),
|
||||
nn.GELU(),
|
||||
nn.MaxPool2d(2)
|
||||
)
|
||||
self.linear_stack = nn.Sequential(
|
||||
nn.Linear(128 * 4 * 4, 128),
|
||||
nn.GELU(),
|
||||
nn.Linear(128, 10)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_stack(x)
|
||||
# x = self.flatten(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
logits = self.linear_stack(x)
|
||||
return logits
|
||||
|
||||
|
||||
feed_forward_model = ConvNetwork().to(device)
|
||||
print(feed_forward_model)
|
||||
|
||||
loss_fn = nn.MSELoss()
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
sgd_optimizer = torch.optim.SGD(feed_forward_model.parameters(), lr=1e-3, momentum=0.9)
|
||||
adam = torch.optim.AdamW(feed_forward_model.parameters(), lr=1e-3)
|
||||
|
||||
def train(dataloader, model, loss_fn, optimizer):
|
||||
size = len(dataloader.dataset)
|
||||
model.train()
|
||||
for batch, (X, y) in enumerate(dataloader):
|
||||
X, y = X.to(device), y.to(device)
|
||||
|
||||
# with profile(activities=[
|
||||
# ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
|
||||
# with record_function("model_inference"):
|
||||
pred = model(X)
|
||||
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
|
||||
loss = loss_fn(pred, y) #loss_fn(pred, F.one_hot(y, num_classes=10).float())
|
||||
|
||||
# Backpropagation
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if batch % 100 == 0:
|
||||
loss = loss.item()
|
||||
current = (batch + 1) * len(X)
|
||||
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
|
||||
|
||||
|
||||
def test(dataloader, model, loss_fn):
|
||||
size = len(dataloader.dataset)
|
||||
num_batches = len(dataloader)
|
||||
model.eval()
|
||||
test_loss, correct = 0, 0
|
||||
misclassified_images = []
|
||||
with torch.no_grad():
|
||||
for X, y in dataloader:
|
||||
X, y = X.to(device), y.to(device)
|
||||
|
||||
pred = model(X)
|
||||
|
||||
test_loss += loss_fn(pred, y).item()
|
||||
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
|
||||
# Collect misclassified images
|
||||
misclassified_images.extend(
|
||||
[(img, true_label, pred_label) for img, true_label, pred_label in zip(X, y, pred.argmax(1)) if
|
||||
true_label != pred_label])
|
||||
test_loss /= num_batches
|
||||
correct /= size
|
||||
print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
|
||||
return misclassified_images
|
||||
|
||||
|
||||
|
||||
epochs = 20
|
||||
start_time = time.time()
|
||||
for t in range(epochs):
|
||||
print(
|
||||
f"Epoch {t + 1} (lr: {sgd_optimizer.state_dict()['param_groups'][0]['lr']})\n-------------------------------")
|
||||
train(train_dataloader,
|
||||
feed_forward_model,
|
||||
loss_fn,
|
||||
adam)
|
||||
test(test_dataloader,
|
||||
feed_forward_model,
|
||||
loss_fn)
|
||||
print("Done!")
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
print(f"Training took {duration:.2f} seconds")
|
||||
|
||||
misclassified_images = test(test_dataloader, feed_forward_model, loss_fn)
|
||||
print(len(misclassified_images))
|
||||
|
||||
# # Visualizing misclassified images
|
||||
# plt.figure(figsize=(10, 10))
|
||||
# for i, (image, true_label, pred_label) in enumerate(misclassified_images):
|
||||
# # plt.subplot(5, 5, i + 1)
|
||||
# plt.imshow(image.cpu().squeeze(), cmap='gray')
|
||||
# plt.title(f'True: {true_label.item()}\nPredicted: {pred_label.item()}')
|
||||
# plt.axis('off')
|
||||
# plt.show()
|
||||
Reference in New Issue
Block a user