accidentals/train.py
dullfig afc16c2bbb Initial commit: accidental classifier (sharp/flat/natural)
CNN trained on PrIMuS crops achieves 100% on held-out test set.
Includes training pipeline, evaluation script, extraction tools,
and saved model weights.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 08:03:56 -08:00

315 lines
10 KiB
Python

"""Train a small CNN to classify accidentals (sharp/flat/natural).
Outputs an ONNX model for inference via OpenCV's dnn module.
"""
import os
import sys
import time
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from PIL import Image
from torchvision import transforms
# Force UTF-8 output: Windows consoles may default to a legacy codepage.
sys.stdout.reconfigure(encoding="utf-8")
# Input: one subdirectory per label under CROPS_ROOT, each full of PNG crops.
CROPS_ROOT = Path(r"C:\src\accidentals\crops")
# Output: trained .pt checkpoint and exported .onnx model are written here.
OUTPUT_DIR = Path(r"C:\src\accidentals\model")
LABELS = ["flat", "natural", "sharp"] # alphabetical → class indices 0,1,2
IMG_SIZE = (40, 40) # square, small enough for fast inference
BATCH_SIZE = 256
NUM_EPOCHS = 30
LR = 1e-3  # initial Adam learning rate; halved on val-accuracy plateau by the scheduler
SEED = 42  # seeds random/numpy/torch in main() for reproducibility
VAL_SPLIT = 0.1  # per-class fraction held out for validation
TEST_SPLIT = 0.05  # per-class fraction held out for the final test set
MAX_PER_CLASS = 5000 # subsample: more than enough for this easy problem
class InMemoryDataset(Dataset):
    """Dataset over pre-loaded tensors held entirely in RAM.

    Optionally applies light train-time augmentation (small affine jitter
    plus occasional random erasing) per sample.
    """

    def __init__(self, images: torch.Tensor, labels: torch.Tensor, augment=False):
        # images: (N, 1, H, W) float32; labels: (N,) long
        self.images = images
        self.labels = labels
        self.augment = augment
        self.affine = transforms.RandomAffine(
            degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)
        )
        self.erasing = transforms.RandomErasing(p=0.2, scale=(0.02, 0.08))

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        sample = self.images[idx]
        if not self.augment:
            return sample, self.labels[idx]
        # Augmentations are sampled fresh on every access.
        return self.erasing(self.affine(sample)), self.labels[idx]
class AccidentalCNN(nn.Module):
    """Small CNN for 3-class accidental classification.

    Input: 1x40x40 grayscale tensor.
    Three conv/BN/ReLU/max-pool blocks halve the spatial size each time
    (40 -> 20 -> 10 -> 5) while widening channels (1 -> 16 -> 32 -> 64),
    followed by a dropout-regularized two-layer classifier head.
    """

    def __init__(self):
        super().__init__()
        # Build the three conv blocks in a loop; module indices in the
        # resulting nn.Sequential match the hand-written equivalent.
        layers = []
        for c_in, c_out in ((1, 16), (16, 32), (32, 64)):
            layers.extend([
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ])
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(64 * 5 * 5, 64),  # 64 channels x 5x5 spatial after pooling
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(64, len(LABELS)),
        )

    def forward(self, x):
        feats = self.features(x)
        return self.classifier(feats.flatten(1))
def load_all_into_ram():
    """Load all crop images into RAM as tensors.

    Reads PNG files from CROPS_ROOT/<label>/ for each label in LABELS,
    converts each to grayscale, resizes to IMG_SIZE, and scales pixel
    values to [0, 1]. Per-class file lists are shuffled before being
    capped at MAX_PER_CLASS so the subsample is random.

    Returns:
        (images, labels): images is a (N, 1, H, W) float32 tensor,
        labels is a (N,) long tensor of class indices into LABELS.
    """
    print("Loading all images into RAM...", flush=True)
    all_images = []
    all_labels = []
    for label_idx, label_name in enumerate(LABELS):
        label_dir = CROPS_ROOT / label_name
        files = [f for f in os.listdir(label_dir) if f.endswith(".png")]
        random.shuffle(files)
        if MAX_PER_CLASS and len(files) > MAX_PER_CLASS:
            files = files[:MAX_PER_CLASS]
        print(f" {label_name}: {len(files)} files", flush=True)
        for fname in files:
            # Context manager closes the file handle promptly; Image.open is
            # lazy and otherwise keeps the handle open until the image object
            # is garbage-collected, which can exhaust descriptors on big runs.
            with Image.open(label_dir / fname) as img:
                img = img.convert("L").resize(IMG_SIZE, Image.BILINEAR)
                arr = np.asarray(img, dtype=np.float32) / 255.0
            all_images.append(arr)
            all_labels.append(label_idx)
    # Stack into tensors: (N, 1, H, W). from_numpy shares memory with the
    # stacked array, avoiding the extra copy torch.tensor() would make.
    images = torch.from_numpy(np.stack(all_images)).unsqueeze(1)
    labels = torch.tensor(all_labels, dtype=torch.long)
    print(f" Total: {len(labels)} images, tensor shape: {images.shape}", flush=True)
    print(f" Memory: {images.nbytes / 1024**2:.0f} MB", flush=True)
    return images, labels
def split_data(images, labels):
    """Stratified train/val/test split. Returns index arrays.

    Indices are shuffled once, then bucketed by class so every split keeps
    (approximately) the global class proportions. Each class contributes at
    least one sample to both the test and validation splits.
    """
    shuffled = torch.randperm(len(labels))
    per_class = {c: [] for c in range(len(LABELS))}
    for i in shuffled.tolist():
        per_class[int(labels[i])].append(i)
    train_idx, val_idx, test_idx = [], [], []
    for idxs in per_class.values():
        count = len(idxs)
        n_test = max(1, int(count * TEST_SPLIT))
        n_val = max(1, int(count * VAL_SPLIT))
        test_idx += idxs[:n_test]
        val_idx += idxs[n_test : n_test + n_val]
        train_idx += idxs[n_test + n_val :]
    return (
        torch.tensor(train_idx),
        torch.tensor(val_idx),
        torch.tensor(test_idx),
    )
def make_weighted_sampler(labels):
    """Build a sampler that rebalances classes by oversampling minorities.

    Each sample's draw weight is the reciprocal of its class frequency, so
    all classes are drawn (with replacement) at roughly equal rates.
    """
    counts = torch.bincount(labels, minlength=len(LABELS)).float()
    per_sample_weight = counts[labels].reciprocal()
    return WeightedRandomSampler(per_sample_weight, len(labels), replacement=True)
def evaluate(model, loader, device):
    """Run the model over a loader in eval mode.

    Returns (avg_loss, overall_accuracy, per_class_accuracy) where the
    per-class accuracies are a numpy array indexed like LABELS.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    per_class_hits = np.zeros(len(LABELS))
    per_class_seen = np.zeros(len(LABELS))
    loss_sum = 0.0
    batch_count = 0
    with torch.no_grad():
        for batch_imgs, batch_lbls in loader:
            batch_imgs = batch_imgs.to(device)
            batch_lbls = batch_lbls.to(device)
            logits = model(batch_imgs)
            loss_sum += F.cross_entropy(logits, batch_lbls).item()
            batch_count += 1
            preds = logits.argmax(dim=1)
            n_seen += batch_lbls.size(0)
            n_correct += (preds == batch_lbls).sum().item()
            # Accumulate per-class hit/seen counts for the class breakdown.
            for cls in range(len(LABELS)):
                is_cls = batch_lbls == cls
                per_class_seen[cls] += is_cls.sum().item()
                per_class_hits[cls] += (preds[is_cls] == cls).sum().item()
    overall = n_correct / n_seen if n_seen > 0 else 0
    per_class = np.where(per_class_seen > 0, per_class_hits / per_class_seen, 0)
    mean_loss = loss_sum / batch_count if batch_count > 0 else 0
    return mean_loss, overall, per_class
def main():
    """End-to-end pipeline: load data, train with best-checkpoint saving,
    evaluate on the held-out test split, export to ONNX, and smoke-test
    the exported model through OpenCV's dnn module."""
    # Seed all three RNGs so subsampling, splitting, and training are reproducible.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    device = torch.device("cpu")
    print(f"Device: {device}", flush=True)
    # Load everything into RAM
    images, labels = load_all_into_ram()
    # Split
    train_idx, val_idx, test_idx = split_data(images, labels)
    print(f"\nTrain: {len(train_idx)} Val: {len(val_idx)} Test: {len(test_idx)}", flush=True)
    # Report the per-class composition of each split.
    for i, name in enumerate(LABELS):
        nt = (labels[train_idx] == i).sum().item()
        nv = (labels[val_idx] == i).sum().item()
        ne = (labels[test_idx] == i).sum().item()
        print(f" {name:8s}: train={nt:6d} val={nv:5d} test={ne:5d}", flush=True)
    # Create datasets (in-memory, fast); augmentation only on the train split.
    train_ds = InMemoryDataset(images[train_idx], labels[train_idx], augment=True)
    val_ds = InMemoryDataset(images[val_idx], labels[val_idx], augment=False)
    test_ds = InMemoryDataset(images[test_idx], labels[test_idx], augment=False)
    # Weighted sampler for class imbalance
    sampler = make_weighted_sampler(labels[train_idx])
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    # Model
    model = AccidentalCNN().to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {param_count:,}", flush=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # Halve the LR whenever validation accuracy stops improving for 3 epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=3
    )
    # Training loop
    best_val_acc = 0.0
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    best_model_path = OUTPUT_DIR / "accidental_classifier.pt"
    # Table header for the per-epoch progress lines printed below.
    print(f"\n{'Epoch':>5s} {'TrLoss':>7s} {'VaLoss':>7s} {'VaAcc':>6s} "
          f"{'Flat':>5s} {'Nat':>5s} {'Sharp':>5s} {'LR':>8s} {'Time':>5s}", flush=True)
    print("-" * 70, flush=True)
    for epoch in range(1, NUM_EPOCHS + 1):
        t0 = time.time()
        model.train()
        train_loss = 0.0
        n_batches = 0
        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = F.cross_entropy(outputs, lbls)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_batches += 1
        avg_train_loss = train_loss / n_batches
        val_loss, val_acc, class_acc = evaluate(model, val_loader, device)
        scheduler.step(val_acc)
        elapsed = time.time() - t0
        lr = optimizer.param_groups[0]["lr"]
        print(
            f"{epoch:5d} {avg_train_loss:7.4f} {val_loss:7.4f} {val_acc:6.1%} "
            f"{class_acc[0]:5.1%} {class_acc[1]:5.1%} {class_acc[2]:5.1%} "
            f"{lr:.6f} {elapsed:5.1f}s",
            flush=True,
        )
        # Checkpoint whenever validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
    # Final evaluation on test set
    print(f"\n=== Test Set Evaluation ===", flush=True)
    # Reload the best (not necessarily last) checkpoint before measuring test accuracy.
    model.load_state_dict(torch.load(best_model_path, weights_only=True))
    test_loss, test_acc, test_class_acc = evaluate(model, test_loader, device)
    print(f"Test accuracy: {test_acc:.1%}", flush=True)
    for i, name in enumerate(LABELS):
        print(f" {name:8s}: {test_class_acc[i]:.1%}", flush=True)
    # Export to ONNX for OpenCV dnn inference
    print(f"\nExporting to ONNX...", flush=True)
    model.eval()
    dummy = torch.randn(1, 1, *IMG_SIZE)
    onnx_path = OUTPUT_DIR / "accidental_classifier.onnx"
    torch.onnx.export(
        model,
        dummy,
        str(onnx_path),
        input_names=["image"],
        output_names=["logits"],
        # Dynamic batch axis so inference can run with any batch size.
        dynamic_axes={"image": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=13,
    )
    print(f"Saved ONNX model: {onnx_path}", flush=True)
    print(f"Saved PyTorch model: {best_model_path}", flush=True)
    # Quick ONNX verification via OpenCV.
    # Imported here so training itself works without opencv-python installed.
    import cv2
    net = cv2.dnn.readNetFromONNX(str(onnx_path))
    # numpy shape is (height, width); blobFromImage's size argument is (width, height).
    test_img = np.random.randint(0, 255, (IMG_SIZE[1], IMG_SIZE[0]), dtype=np.uint8)
    blob = cv2.dnn.blobFromImage(test_img, 1.0 / 255.0, IMG_SIZE, 0, swapRB=False)
    net.setInput(blob)
    out = net.forward()
    pred = LABELS[np.argmax(out[0])]
    print(f"ONNX+OpenCV smoke test: input=random noise -> pred={pred} (expected: arbitrary)", flush=True)
    print(f"\nDone. Labels: {LABELS}", flush=True)
# Entry-point guard: allows importing AccidentalCNN from this module
# (e.g. for inference) without triggering a training run.
if __name__ == "__main__":
    main()