"""Evaluate the saved accidental classifier on the held-out test set.

Uses the same seed (``SEED`` from train.py), split ratios, and data loading as
train.py so the test set is identical to what training used.
"""

import os
import random
from pathlib import Path
from collections import defaultdict

import numpy as np
import torch
from PIL import Image

from train import (
    AccidentalCNN,
    CROPS_ROOT,
    OUTPUT_DIR,
    LABELS,
    IMG_SIZE,
    SEED,
    VAL_SPLIT,
    TEST_SPLIT,
    MAX_PER_CLASS,
    InMemoryDataset,
    load_all_into_ram,
    split_data,
)
from torch.utils.data import DataLoader

MODEL_PATH = OUTPUT_DIR / "accidental_classifier.pt"


def collect_file_paths():
    """Collect crop file paths in the same order as ``load_all_into_ram``.

    For each label in ``LABELS``, lists the ``*.png`` files in
    ``CROPS_ROOT / label``, shuffles them with the global ``random`` module,
    and keeps at most ``MAX_PER_CLASS`` of them (when it is truthy).

    Must be called right after seeding ``random`` with ``SEED`` so that
    ``random.shuffle`` consumes the RNG the same way ``load_all_into_ram``
    does. NOTE(review): this mirrors train.py's loader by convention only;
    re-check it whenever ``load_all_into_ram`` changes.

    Returns:
        list[str]: file paths, in label order then shuffled order per label.
    """
    paths = []
    for label_name in LABELS:
        label_dir = CROPS_ROOT / label_name
        files = [f for f in os.listdir(label_dir) if f.endswith(".png")]
        random.shuffle(files)
        if MAX_PER_CLASS and len(files) > MAX_PER_CLASS:
            files = files[:MAX_PER_CLASS]
        paths.extend(str(label_dir / fname) for fname in files)
    return paths


def _seed_all(seed):
    """Seed the Python, NumPy and torch global RNGs with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _paths_match_labels(file_paths, labels):
    """Return True if ``file_paths`` plausibly lines up with ``labels``.

    Checks that both have the same length and that each path's parent
    directory is named ``LABELS[label]``. This catches length and
    class-boundary drift between ``collect_file_paths`` and
    ``load_all_into_ram``. A reordering *within* one class cannot be
    detected here.
    """
    label_ids = labels.tolist()
    if len(file_paths) != len(label_ids):
        return False
    return all(
        Path(path).parent.name == LABELS[label_id]
        for path, label_id in zip(file_paths, label_ids)
    )


def _predict(model, dataset, device, batch_size=256):
    """Run ``model`` over ``dataset`` in order.

    Returns:
        (predictions, true_labels): two 1-D CPU tensors of equal length.
        Predictions are the argmax class index (ties go to the first max).
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    preds, trues = [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            outputs = model(imgs.to(device))
            preds.append(outputs.argmax(dim=1).cpu())
            trues.append(lbls)
    return torch.cat(preds), torch.cat(trues)


def _confusion_matrix(true, pred, n_classes):
    """Return an ``(n_classes, n_classes)`` int64 count matrix (rows=true, cols=pred)."""
    # Encode each (true, pred) pair as one flat bin index and count in a single pass.
    flat = true.long() * n_classes + pred.long()
    counts = torch.bincount(flat, minlength=n_classes * n_classes)
    return counts.reshape(n_classes, n_classes)


def main():
    """Rebuild the test split, evaluate the saved model on it, and print a report."""
    # Reproduce the exact same RNG state and splits as train.py.
    _seed_all(SEED)
    # Collect file paths (consumes the same random state as load_all_into_ram).
    file_paths = collect_file_paths()

    # Re-seed and reload to get tensors (load_all_into_ram shuffles again internally).
    _seed_all(SEED)
    images, labels = load_all_into_ram()
    _, _, test_idx = split_data(images, labels)
    test_labels = labels[test_idx]

    print(f"\nTest set size: {len(test_idx)}")
    for i, name in enumerate(LABELS):
        n = (test_labels == i).sum().item()
        print(f"  {name:8s}: {n}")

    if len(test_idx) == 0:
        # Nothing to evaluate; avoids torch.cat([]) and division by zero below.
        print("\nTest set is empty; nothing to evaluate.")
        return

    # Load model. map_location lets weights saved from a GPU run load on CPU.
    device = torch.device("cpu")
    model = AccidentalCNN().to(device)
    state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"\nLoaded weights from {MODEL_PATH}")

    # Run inference.
    test_ds = InMemoryDataset(images[test_idx], test_labels, augment=False)
    all_preds, all_labels = _predict(model, test_ds, device)

    # Overall accuracy.
    correct = (all_preds == all_labels).sum().item()
    total = len(all_labels)
    print("\n=== Test Results ===")
    print(f"Overall accuracy: {correct}/{total} ({correct/total:.1%})")

    # Per-class accuracy.
    print("\nPer-class accuracy:")
    for i, name in enumerate(LABELS):
        mask = all_labels == i
        cls_total = mask.sum().item()
        cls_correct = (all_preds[mask] == i).sum().item()
        acc = cls_correct / cls_total if cls_total > 0 else 0
        print(f"  {name:8s}: {cls_correct}/{cls_total} ({acc:.1%})")

    # Confusion matrix.
    n_classes = len(LABELS)
    cm = _confusion_matrix(all_labels, all_preds, n_classes).tolist()
    print("\nConfusion matrix (rows=true, cols=predicted):")
    print("          " + "".join(f"{name:>9s}" for name in LABELS))
    for name, row_counts in zip(LABELS, cm):
        print(f"  {name:8s}" + "".join(f"{count:9d}" for count in row_counts))

    # Sample misclassifications. Paths are only trustworthy if the path list
    # lines up with the loaded tensors; otherwise fall back to dataset indices.
    paths_ok = _paths_match_labels(file_paths, labels)
    if not paths_ok:
        print(
            "\nWARNING: collected file paths do not line up with the loaded labels; "
            "reporting dataset indices instead of paths."
        )
    test_indices = test_idx.tolist()
    wrong_positions = (all_preds != all_labels).nonzero(as_tuple=True)[0].tolist()
    misclassified = []
    for i in wrong_positions:
        orig_idx = test_indices[i]
        where = file_paths[orig_idx] if paths_ok else f"<dataset index {orig_idx}>"
        misclassified.append((
            where,
            LABELS[all_labels[i].item()],
            LABELS[all_preds[i].item()],
        ))

    print(f"\nMisclassifications: {len(misclassified)}/{total}")
    if misclassified:
        print("\nSample misclassifications (up to 10):")
        for path, true_lbl, pred_lbl in misclassified[:10]:
            print(f"  {path}")
            print(f"    true={true_lbl}  predicted={pred_lbl}")


if __name__ == "__main__":
    main()