accidentals/eval.py
dullfig ebc925482e Initial commit: accidental classifier (sharp/flat/natural)
CNN trained on PrIMuS crops achieves 100% on held-out test set.
Includes training pipeline, evaluation script, extraction tools,
and saved model weights.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 08:01:37 -08:00

148 lines
4.4 KiB
Python

"""Evaluate the saved accidental classifier on the held-out test set.
Uses the same seed (42), split ratios, and data loading as train.py
so the test set is identical to what training used.
"""
import os
import random
from pathlib import Path
from collections import defaultdict
import numpy as np
import torch
from PIL import Image
from train import (
AccidentalCNN,
CROPS_ROOT,
OUTPUT_DIR,
LABELS,
IMG_SIZE,
SEED,
VAL_SPLIT,
TEST_SPLIT,
MAX_PER_CLASS,
InMemoryDataset,
load_all_into_ram,
split_data,
)
from torch.utils.data import DataLoader
# Location of the saved model weights that main() loads for evaluation.
MODEL_PATH = OUTPUT_DIR / "accidental_classifier.pt"
def collect_file_paths():
    """Return the crop file paths of every label, in shuffled order.

    For each label in ``LABELS`` the ``.png`` entries of its directory under
    ``CROPS_ROOT`` are listed, shuffled with the global ``random`` state and,
    when ``MAX_PER_CLASS`` is set and exceeded, truncated to that many files.

    The caller must seed ``random`` beforehand; the intent is that the
    shuffles here reproduce the ordering produced by ``load_all_into_ram``
    in train.py, so the returned list can be indexed with dataset indices.
    """
    collected = []
    for class_name in LABELS:
        class_dir = CROPS_ROOT / class_name
        png_names = [
            entry for entry in os.listdir(class_dir) if entry.endswith(".png")
        ]
        # Exactly one shuffle per class: this consumes the global RNG state,
        # so the call order must not change.
        random.shuffle(png_names)
        over_cap = MAX_PER_CLASS and len(png_names) > MAX_PER_CLASS
        if over_cap:
            png_names = png_names[:MAX_PER_CLASS]
        collected.extend(str(class_dir / name) for name in png_names)
    return collected
def _seed_all():
    """Seed Python's, NumPy's and PyTorch's RNGs with the shared ``SEED``."""
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)


def main():
    """Evaluate the saved classifier on the test split and print a report.

    Steps:
      1. Seed the RNGs and collect file paths, then re-seed and load the
         image/label tensors, so both passes start from the same RNG state.
      2. Recreate the train/val/test split and keep only the test indices.
      3. Load the saved weights onto the CPU and run inference.
      4. Print overall accuracy, per-class accuracy, a confusion matrix and
         up to ten misclassified file paths.

    Raises:
        SystemExit: if the test split is empty, since no metric can be
            computed in that case.
    """
    # Reproduce the exact same RNG state and splits as train.py
    _seed_all()
    # Collect file paths (consumes same random state as load_all_into_ram)
    file_paths = collect_file_paths()

    # Re-seed and reload to get tensors (load_all_into_ram shuffles again
    # internally, so it must start from the same RNG state as above)
    _seed_all()
    images, labels = load_all_into_ram()
    _, _, test_idx = split_data(images, labels)

    # The misclassification report indexes file_paths with dataset indices;
    # that is only meaningful if both listings have the same length.
    if len(file_paths) != len(images):
        print(
            f"WARNING: collected {len(file_paths)} file paths but loaded "
            f"{len(images)} images; misclassification paths may be wrong."
        )

    print(f"\nTest set size: {len(test_idx)}")
    if len(test_idx) == 0:
        # Without this guard the script fails later with an obscure
        # torch.cat / ZeroDivisionError instead of a clear message.
        raise SystemExit("Test split is empty; nothing to evaluate.")

    test_images = images[test_idx]
    test_labels = labels[test_idx]
    for i, name in enumerate(LABELS):
        n = (test_labels == i).sum().item()
        print(f" {name:8s}: {n}")

    # Load model
    device = torch.device("cpu")
    model = AccidentalCNN().to(device)
    # map_location makes a checkpoint saved from a GPU loadable on a
    # CPU-only machine; it is a no-op for checkpoints saved from the CPU.
    state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"\nLoaded weights from {MODEL_PATH}")

    # Run inference
    test_ds = InMemoryDataset(test_images, test_labels, augment=False)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=0)
    pred_batches = []
    label_batches = []
    with torch.no_grad():
        for imgs, lbls in test_loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            _, predicted = outputs.max(1)
            pred_batches.append(predicted.cpu())
            label_batches.append(lbls)
    all_preds = torch.cat(pred_batches)
    all_labels = torch.cat(label_batches)

    # Overall accuracy
    correct = (all_preds == all_labels).sum().item()
    total = len(all_labels)
    print("\n=== Test Results ===")
    print(f"Overall accuracy: {correct}/{total} ({correct/total:.1%})")

    # Per-class accuracy
    print("\nPer-class accuracy:")
    for i, name in enumerate(LABELS):
        mask = all_labels == i
        cls_total = mask.sum().item()
        cls_correct = (all_preds[mask] == i).sum().item()
        acc = cls_correct / cls_total if cls_total > 0 else 0
        print(f" {name:8s}: {cls_correct}/{cls_total} ({acc:.1%})")

    # Plain Python ints are cheaper to iterate and index with than
    # zero-dimensional tensors.
    true_list = all_labels.tolist()
    pred_list = all_preds.tolist()

    # Confusion matrix
    n_classes = len(LABELS)
    cm = torch.zeros(n_classes, n_classes, dtype=torch.long)
    for t, p in zip(true_list, pred_list):
        cm[t, p] += 1
    print("\nConfusion matrix (rows=true, cols=predicted):")
    header = " " + "".join(f"{name:>9s}" for name in LABELS)
    print(header)
    for i, name in enumerate(LABELS):
        row = f" {name:8s}" + "".join(
            f"{cm[i][j].item():9d}" for j in range(n_classes)
        )
        print(row)

    # Sample misclassifications: the loader does not shuffle, so position k
    # in the predictions corresponds to dataset index test_idx[k].
    test_indices = test_idx.tolist()
    misclassified = [
        (file_paths[orig_idx], LABELS[t], LABELS[p])
        for orig_idx, t, p in zip(test_indices, true_list, pred_list)
        if t != p
    ]
    print(f"\nMisclassifications: {len(misclassified)}/{total}")
    if misclassified:
        print("\nSample misclassifications (up to 10):")
        for path, true_lbl, pred_lbl in misclassified[:10]:
            print(f" {path}")
            print(f" true={true_lbl} predicted={pred_lbl}")
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()