got over 99%
This commit is contained in:
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Python 3.10" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
8
.idea/ml.iml
generated
Normal file
8
.idea/ml.iml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/ml.iml" filepath="$PROJECT_DIR$/.idea/ml.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/other.xml
generated
Normal file
6
.idea/other.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="PySciProjectComponent">
|
||||||
|
<option name="PY_INTERACTIVE_PLOTS_SUGGESTED" value="true" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
199
ct.py
Normal file
199
ct.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision import datasets
|
||||||
|
from torchvision.transforms import ToTensor
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import time
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from torch.profiler import profile, record_function, ProfilerActivity
|
||||||
|
|
||||||
|
torch.set_flush_denormal(True)
|
||||||
|
|
||||||
|
training_data = datasets.MNIST(
|
||||||
|
root="data",
|
||||||
|
train=True,
|
||||||
|
download=True,
|
||||||
|
transform=ToTensor(),
|
||||||
|
)
|
||||||
|
|
||||||
|
test_data = datasets.MNIST(
|
||||||
|
root="data",
|
||||||
|
train=False,
|
||||||
|
download=True,
|
||||||
|
transform=ToTensor(),
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size = 64
|
||||||
|
|
||||||
|
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=4)
|
||||||
|
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=4)
|
||||||
|
|
||||||
|
for input_data, labels in test_dataloader:
|
||||||
|
print(f"Shape of input_data [N, C, H, W]: {input_data.shape}")
|
||||||
|
print(f"Shape of labels: {labels.shape} (type: {labels.dtype})")
|
||||||
|
break
|
||||||
|
|
||||||
|
device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
# device = "cpu"
|
||||||
|
print(f"Using {device} device")
|
||||||
|
|
||||||
|
|
||||||
|
class ConvNetwork(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.flatten = nn.Flatten()
|
||||||
|
# self.conv_stack = nn.Sequential(
|
||||||
|
# nn.Conv2d(1, 1, 3, 1, padding=0),
|
||||||
|
# nn.GELU(),
|
||||||
|
# # nn.MaxPool2d(2, 1),
|
||||||
|
# # nn.GELU(),
|
||||||
|
# nn.Conv2d(1, 1, 3, 1, padding=0),
|
||||||
|
# nn.GELU(),
|
||||||
|
# # nn.Conv2d(1, 1, 3, 1, padding=0, device=device),
|
||||||
|
# # nn.GELU(),
|
||||||
|
# # nn.Conv2d(1, 1, 3, 1, padding=0, device=device),
|
||||||
|
# # nn.GELU(),
|
||||||
|
# # nn.Conv2d(1, 1, 3, 1, padding=0, device=device),
|
||||||
|
# # nn.GELU(),
|
||||||
|
# # nn.Conv2d(1, 1, 3, 1, padding=0, device=device),
|
||||||
|
# # nn.GELU(),
|
||||||
|
# # nn.MaxPool2d(),
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# self.linear_stack = nn.Sequential(
|
||||||
|
# nn.Linear(24 * 24, 64),
|
||||||
|
# nn.GELU(),
|
||||||
|
# # nn.Dropout(0.1),
|
||||||
|
# nn.Linear(64, 32),
|
||||||
|
# nn.GELU(),
|
||||||
|
# # nn.Linear(256, 128),
|
||||||
|
# # nn.GELU(),
|
||||||
|
# # nn.Dropout(0.2),
|
||||||
|
# # nn.Linear(128, 64),
|
||||||
|
# # nn.GELU(),
|
||||||
|
# # nn.Dropout(0.2),
|
||||||
|
# nn.Linear(32, 10),
|
||||||
|
# # nn.GELU(),
|
||||||
|
# # nn.Linear(512, 10),
|
||||||
|
# )
|
||||||
|
|
||||||
|
self.conv_stack = nn.Sequential(
|
||||||
|
nn.Conv2d(1, 32, 3, 1),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv2d(32, 64, 3, 1),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.Conv2d(64, 128, 3, 1),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv2d(128, 128, 3, 1),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.MaxPool2d(2)
|
||||||
|
)
|
||||||
|
self.linear_stack = nn.Sequential(
|
||||||
|
nn.Linear(128 * 4 * 4, 128),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(128, 10)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv_stack(x)
|
||||||
|
# x = self.flatten(x)
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
logits = self.linear_stack(x)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
feed_forward_model = ConvNetwork().to(device)
|
||||||
|
print(feed_forward_model)
|
||||||
|
|
||||||
|
loss_fn = nn.MSELoss()
|
||||||
|
loss_fn = nn.CrossEntropyLoss()
|
||||||
|
sgd_optimizer = torch.optim.SGD(feed_forward_model.parameters(), lr=1e-3, momentum=0.9)
|
||||||
|
adam = torch.optim.AdamW(feed_forward_model.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
def train(dataloader, model, loss_fn, optimizer):
|
||||||
|
size = len(dataloader.dataset)
|
||||||
|
model.train()
|
||||||
|
for batch, (X, y) in enumerate(dataloader):
|
||||||
|
X, y = X.to(device), y.to(device)
|
||||||
|
|
||||||
|
# with profile(activities=[
|
||||||
|
# ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
|
||||||
|
# with record_function("model_inference"):
|
||||||
|
pred = model(X)
|
||||||
|
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||||
|
|
||||||
|
loss = loss_fn(pred, y) #loss_fn(pred, F.one_hot(y, num_classes=10).float())
|
||||||
|
|
||||||
|
# Backpropagation
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if batch % 100 == 0:
|
||||||
|
loss = loss.item()
|
||||||
|
current = (batch + 1) * len(X)
|
||||||
|
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
|
||||||
|
|
||||||
|
|
||||||
|
def test(dataloader, model, loss_fn):
|
||||||
|
size = len(dataloader.dataset)
|
||||||
|
num_batches = len(dataloader)
|
||||||
|
model.eval()
|
||||||
|
test_loss, correct = 0, 0
|
||||||
|
misclassified_images = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for X, y in dataloader:
|
||||||
|
X, y = X.to(device), y.to(device)
|
||||||
|
|
||||||
|
pred = model(X)
|
||||||
|
|
||||||
|
test_loss += loss_fn(pred, y).item()
|
||||||
|
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
|
||||||
|
# Collect misclassified images
|
||||||
|
misclassified_images.extend(
|
||||||
|
[(img, true_label, pred_label) for img, true_label, pred_label in zip(X, y, pred.argmax(1)) if
|
||||||
|
true_label != pred_label])
|
||||||
|
test_loss /= num_batches
|
||||||
|
correct /= size
|
||||||
|
print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
|
||||||
|
return misclassified_images
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
epochs = 20
|
||||||
|
start_time = time.time()
|
||||||
|
for t in range(epochs):
|
||||||
|
print(
|
||||||
|
f"Epoch {t + 1} (lr: {sgd_optimizer.state_dict()['param_groups'][0]['lr']})\n-------------------------------")
|
||||||
|
train(train_dataloader,
|
||||||
|
feed_forward_model,
|
||||||
|
loss_fn,
|
||||||
|
adam)
|
||||||
|
test(test_dataloader,
|
||||||
|
feed_forward_model,
|
||||||
|
loss_fn)
|
||||||
|
print("Done!")
|
||||||
|
end_time = time.time()
|
||||||
|
duration = end_time - start_time
|
||||||
|
print(f"Training took {duration:.2f} seconds")
|
||||||
|
|
||||||
|
misclassified_images = test(test_dataloader, feed_forward_model, loss_fn)
|
||||||
|
print(len(misclassified_images))
|
||||||
|
|
||||||
|
# # Visualizing misclassified images
|
||||||
|
# plt.figure(figsize=(10, 10))
|
||||||
|
# for i, (image, true_label, pred_label) in enumerate(misclassified_images):
|
||||||
|
# # plt.subplot(5, 5, i + 1)
|
||||||
|
# plt.imshow(image.cpu().squeeze(), cmap='gray')
|
||||||
|
# plt.title(f'True: {true_label.item()}\nPredicted: {pred_label.item()}')
|
||||||
|
# plt.axis('off')
|
||||||
|
# plt.show()
|
||||||
BIN
data/MNIST/raw/t10k-images-idx3-ubyte
Normal file
BIN
data/MNIST/raw/t10k-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
data/MNIST/raw/t10k-images-idx3-ubyte.gz
Normal file
BIN
data/MNIST/raw/t10k-images-idx3-ubyte.gz
Normal file
Binary file not shown.
BIN
data/MNIST/raw/t10k-labels-idx1-ubyte
Normal file
BIN
data/MNIST/raw/t10k-labels-idx1-ubyte
Normal file
Binary file not shown.
BIN
data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Normal file
BIN
data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Normal file
Binary file not shown.
BIN
data/MNIST/raw/train-images-idx3-ubyte
Normal file
BIN
data/MNIST/raw/train-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
data/MNIST/raw/train-images-idx3-ubyte.gz
Normal file
BIN
data/MNIST/raw/train-images-idx3-ubyte.gz
Normal file
Binary file not shown.
BIN
data/MNIST/raw/train-labels-idx1-ubyte
Normal file
BIN
data/MNIST/raw/train-labels-idx1-ubyte
Normal file
Binary file not shown.
BIN
data/MNIST/raw/train-labels-idx1-ubyte.gz
Normal file
BIN
data/MNIST/raw/train-labels-idx1-ubyte.gz
Normal file
Binary file not shown.
Reference in New Issue
Block a user