CNN trained on PrIMuS crops achieves 100% on held-out test set. Includes training pipeline, evaluation script, extraction tools, and saved model weights. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
148 lines
4.4 KiB
Python
"""Evaluate the saved accidental classifier on the held-out test set.

Uses the same seed (42), split ratios, and data loading as train.py
so the test set is identical to what training used.
"""
|
|
|
|
import os
|
|
import random
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from train import (
|
|
AccidentalCNN,
|
|
CROPS_ROOT,
|
|
OUTPUT_DIR,
|
|
LABELS,
|
|
IMG_SIZE,
|
|
SEED,
|
|
VAL_SPLIT,
|
|
TEST_SPLIT,
|
|
MAX_PER_CLASS,
|
|
InMemoryDataset,
|
|
load_all_into_ram,
|
|
split_data,
|
|
)
|
|
from torch.utils.data import DataLoader
|
|
|
|
# Trained weights written by train.py; loaded below for evaluation.
MODEL_PATH = OUTPUT_DIR / "accidental_classifier.pt"
|
|
|
|
|
|
def collect_file_paths():
    """Collect file paths in the same order as load_all_into_ram.

    Must be called after seeding so random.shuffle matches train.py.

    Returns:
        list[str]: paths to the PNG crops, in the exact order that
        load_all_into_ram would visit them (label by label, shuffled,
        then truncated to MAX_PER_CLASS).
    """
    paths = []
    # Iterate labels in canonical order so the RNG is consumed identically
    # to train.py's loading pass.
    for label_name in LABELS:
        label_dir = CROPS_ROOT / label_name
        files = [f for f in os.listdir(label_dir) if f.endswith(".png")]
        # Shuffle *before* truncating, exactly as load_all_into_ram does,
        # so the same random state selects the same subset.
        random.shuffle(files)
        if MAX_PER_CLASS and len(files) > MAX_PER_CLASS:
            files = files[:MAX_PER_CLASS]
        paths.extend(str(label_dir / fname) for fname in files)
    return paths
|
|
|
|
|
|
def main():
    """Evaluate the saved classifier on the exact test split used by train.py.

    Prints overall and per-class accuracy, a confusion matrix, and a sample
    of misclassified file paths. Relies on reproducing train.py's RNG state
    byte-for-byte; the seeding/loading order below must not change.
    """
    # Reproduce the exact same RNG state and splits as train.py
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Collect file paths (consumes same random state as load_all_into_ram)
    file_paths = collect_file_paths()

    # Re-seed and reload to get tensors (load_all_into_ram shuffles again internally)
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    images, labels = load_all_into_ram()
    _, _, test_idx = split_data(images, labels)

    print(f"\nTest set size: {len(test_idx)}")
    for i, name in enumerate(LABELS):
        n = (labels[test_idx] == i).sum().item()
        print(f" {name:8s}: {n}")

    # Load model on CPU; weights_only=True avoids unpickling arbitrary objects.
    device = torch.device("cpu")
    model = AccidentalCNN().to(device)
    model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))
    model.eval()
    print(f"\nLoaded weights from {MODEL_PATH}")

    # Run inference (no augmentation, fixed order so predictions align
    # one-to-one with test_idx).
    test_ds = InMemoryDataset(images[test_idx], labels[test_idx], augment=False)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=0)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for imgs, lbls in test_loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            _, predicted = outputs.max(1)
            all_preds.append(predicted.cpu())
            all_labels.append(lbls)

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Overall accuracy
    correct = (all_preds == all_labels).sum().item()
    total = len(all_labels)
    print("\n=== Test Results ===")
    print(f"Overall accuracy: {correct}/{total} ({correct/total:.1%})")

    # Per-class accuracy
    print("\nPer-class accuracy:")
    for i, name in enumerate(LABELS):
        mask = all_labels == i
        cls_total = mask.sum().item()
        cls_correct = (all_preds[mask] == i).sum().item()
        acc = cls_correct / cls_total if cls_total > 0 else 0
        print(f" {name:8s}: {cls_correct}/{cls_total} ({acc:.1%})")

    # Confusion matrix
    n_classes = len(LABELS)
    cm = torch.zeros(n_classes, n_classes, dtype=torch.long)
    for t, p in zip(all_labels, all_preds):
        cm[t][p] += 1

    print("\nConfusion matrix (rows=true, cols=predicted):")
    header = " " + "".join(f"{name:>9s}" for name in LABELS)
    print(header)
    for i, name in enumerate(LABELS):
        row = f" {name:8s}" + "".join(f"{cm[i][j].item():9d}" for j in range(n_classes))
        print(row)

    # Sample misclassifications.
    # NOTE(review): this assumes load_all_into_ram yields tensors in the same
    # order as collect_file_paths yields paths (both consume the same RNG
    # state) — confirm against train.py if paths look wrong.
    test_indices = test_idx.tolist()
    misclassified = []
    for i in range(len(all_preds)):
        if all_preds[i] != all_labels[i]:
            orig_idx = test_indices[i]
            misclassified.append((
                file_paths[orig_idx],
                LABELS[all_labels[i].item()],
                LABELS[all_preds[i].item()],
            ))

    print(f"\nMisclassifications: {len(misclassified)}/{total}")
    if misclassified:
        print("\nSample misclassifications (up to 10):")
        for path, true_lbl, pred_lbl in misclassified[:10]:
            print(f" {path}")
            print(f" true={true_lbl} predicted={pred_lbl}")
|