# CNN trained on PrIMuS crops achieves 100% on held-out test set.
# Includes training pipeline, evaluation script, extraction tools, and
# saved model weights.
# Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
"""Train a small CNN to classify accidentals (sharp/flat/natural).
|
|
|
|
Outputs an ONNX model for inference via OpenCV's dnn module.
|
|
"""
|
|
|
|
import os
import random
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import transforms

sys.stdout.reconfigure(encoding="utf-8")

CROPS_ROOT = Path(r"C:\src\accidentals\crops")
OUTPUT_DIR = Path(r"C:\src\accidentals\model")
LABELS = ["flat", "natural", "sharp"]  # alphabetical → class indices 0,1,2
IMG_SIZE = (40, 40)  # square, small enough for fast inference
BATCH_SIZE = 256
NUM_EPOCHS = 30
LR = 1e-3
SEED = 42
VAL_SPLIT = 0.1
TEST_SPLIT = 0.05
MAX_PER_CLASS = 5000  # subsample: more than enough for this easy problem


class InMemoryDataset(Dataset):
|
|
"""Dataset that holds all images as tensors in RAM."""
|
|
|
|
def __init__(self, images: torch.Tensor, labels: torch.Tensor, augment=False):
|
|
self.images = images # (N, 1, H, W) float32
|
|
self.labels = labels # (N,) long
|
|
self.augment = augment
|
|
self.affine = transforms.RandomAffine(
|
|
degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)
|
|
)
|
|
self.erasing = transforms.RandomErasing(p=0.2, scale=(0.02, 0.08))
|
|
|
|
def __len__(self):
|
|
return len(self.labels)
|
|
|
|
def __getitem__(self, idx):
|
|
img = self.images[idx]
|
|
if self.augment:
|
|
img = self.affine(img)
|
|
img = self.erasing(img)
|
|
return img, self.labels[idx]
|
|
|
|
|
|
class AccidentalCNN(nn.Module):
|
|
"""Small CNN for 3-class accidental classification.
|
|
|
|
Input: 1x40x40 grayscale image
|
|
Architecture: 3 conv blocks + 1 FC layer
|
|
"""
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.features = nn.Sequential(
|
|
# Block 1: 1 -> 16 channels, 40x40 -> 20x20
|
|
nn.Conv2d(1, 16, 3, padding=1),
|
|
nn.BatchNorm2d(16),
|
|
nn.ReLU(inplace=True),
|
|
nn.MaxPool2d(2),
|
|
# Block 2: 16 -> 32 channels, 20x20 -> 10x10
|
|
nn.Conv2d(16, 32, 3, padding=1),
|
|
nn.BatchNorm2d(32),
|
|
nn.ReLU(inplace=True),
|
|
nn.MaxPool2d(2),
|
|
# Block 3: 32 -> 64 channels, 10x10 -> 5x5
|
|
nn.Conv2d(32, 64, 3, padding=1),
|
|
nn.BatchNorm2d(64),
|
|
nn.ReLU(inplace=True),
|
|
nn.MaxPool2d(2),
|
|
)
|
|
self.classifier = nn.Sequential(
|
|
nn.Dropout(0.3),
|
|
nn.Linear(64 * 5 * 5, 64),
|
|
nn.ReLU(inplace=True),
|
|
nn.Dropout(0.3),
|
|
nn.Linear(64, len(LABELS)),
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.features(x)
|
|
x = x.view(x.size(0), -1)
|
|
x = self.classifier(x)
|
|
return x
|
|
|
|
|
|
def load_all_into_ram():
|
|
"""Load all images into RAM as tensors. Returns (images, labels) tensors."""
|
|
print("Loading all images into RAM...", flush=True)
|
|
all_images = []
|
|
all_labels = []
|
|
|
|
for label_idx, label_name in enumerate(LABELS):
|
|
label_dir = CROPS_ROOT / label_name
|
|
files = [f for f in os.listdir(label_dir) if f.endswith(".png")]
|
|
random.shuffle(files)
|
|
if MAX_PER_CLASS and len(files) > MAX_PER_CLASS:
|
|
files = files[:MAX_PER_CLASS]
|
|
print(f" {label_name}: {len(files)} files", flush=True)
|
|
|
|
for fname in files:
|
|
img = Image.open(label_dir / fname).convert("L")
|
|
img = img.resize(IMG_SIZE, Image.BILINEAR)
|
|
arr = np.array(img, dtype=np.float32) / 255.0
|
|
all_images.append(arr)
|
|
all_labels.append(label_idx)
|
|
|
|
# Stack into tensors: (N, 1, H, W)
|
|
images = torch.tensor(np.array(all_images)).unsqueeze(1)
|
|
labels = torch.tensor(all_labels, dtype=torch.long)
|
|
print(f" Total: {len(labels)} images, tensor shape: {images.shape}", flush=True)
|
|
print(f" Memory: {images.nbytes / 1024**2:.0f} MB", flush=True)
|
|
return images, labels
|
|
|
|
|
|
def split_data(images, labels):
|
|
"""Stratified train/val/test split. Returns index arrays."""
|
|
indices = torch.randperm(len(labels))
|
|
|
|
by_class = {i: [] for i in range(len(LABELS))}
|
|
for idx in indices:
|
|
by_class[labels[idx].item()].append(idx.item())
|
|
|
|
train_idx, val_idx, test_idx = [], [], []
|
|
for cls, idxs in by_class.items():
|
|
n = len(idxs)
|
|
n_test = max(1, int(n * TEST_SPLIT))
|
|
n_val = max(1, int(n * VAL_SPLIT))
|
|
test_idx.extend(idxs[:n_test])
|
|
val_idx.extend(idxs[n_test : n_test + n_val])
|
|
train_idx.extend(idxs[n_test + n_val :])
|
|
|
|
return (
|
|
torch.tensor(train_idx),
|
|
torch.tensor(val_idx),
|
|
torch.tensor(test_idx),
|
|
)
|
|
|
|
|
|
def make_weighted_sampler(labels):
|
|
"""Create a sampler that oversamples minority classes."""
|
|
class_counts = torch.bincount(labels, minlength=len(LABELS)).float()
|
|
weights = 1.0 / class_counts[labels]
|
|
return WeightedRandomSampler(weights, len(labels), replacement=True)
|
|
|
|
|
|
def evaluate(model, loader, device):
|
|
"""Evaluate model on a data loader."""
|
|
model.eval()
|
|
correct = 0
|
|
total = 0
|
|
class_correct = np.zeros(len(LABELS))
|
|
class_total = np.zeros(len(LABELS))
|
|
total_loss = 0.0
|
|
n_batches = 0
|
|
|
|
with torch.no_grad():
|
|
for imgs, lbls in loader:
|
|
imgs, lbls = imgs.to(device), lbls.to(device)
|
|
outputs = model(imgs)
|
|
loss = F.cross_entropy(outputs, lbls)
|
|
total_loss += loss.item()
|
|
n_batches += 1
|
|
|
|
_, predicted = outputs.max(1)
|
|
total += lbls.size(0)
|
|
correct += predicted.eq(lbls).sum().item()
|
|
|
|
for i in range(len(LABELS)):
|
|
mask = lbls == i
|
|
class_total[i] += mask.sum().item()
|
|
class_correct[i] += (predicted[mask] == i).sum().item()
|
|
|
|
acc = correct / total if total > 0 else 0
|
|
class_acc = np.where(class_total > 0, class_correct / class_total, 0)
|
|
avg_loss = total_loss / n_batches if n_batches > 0 else 0
|
|
return avg_loss, acc, class_acc
|
|
|
|
|
|
def main():
|
|
random.seed(SEED)
|
|
np.random.seed(SEED)
|
|
torch.manual_seed(SEED)
|
|
|
|
device = torch.device("cpu")
|
|
print(f"Device: {device}", flush=True)
|
|
|
|
# Load everything into RAM
|
|
images, labels = load_all_into_ram()
|
|
|
|
# Split
|
|
train_idx, val_idx, test_idx = split_data(images, labels)
|
|
print(f"\nTrain: {len(train_idx)} Val: {len(val_idx)} Test: {len(test_idx)}", flush=True)
|
|
for i, name in enumerate(LABELS):
|
|
nt = (labels[train_idx] == i).sum().item()
|
|
nv = (labels[val_idx] == i).sum().item()
|
|
ne = (labels[test_idx] == i).sum().item()
|
|
print(f" {name:8s}: train={nt:6d} val={nv:5d} test={ne:5d}", flush=True)
|
|
|
|
# Create datasets (in-memory, fast)
|
|
train_ds = InMemoryDataset(images[train_idx], labels[train_idx], augment=True)
|
|
val_ds = InMemoryDataset(images[val_idx], labels[val_idx], augment=False)
|
|
test_ds = InMemoryDataset(images[test_idx], labels[test_idx], augment=False)
|
|
|
|
# Weighted sampler for class imbalance
|
|
sampler = make_weighted_sampler(labels[train_idx])
|
|
|
|
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
|
|
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
|
|
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
|
|
|
|
# Model
|
|
model = AccidentalCNN().to(device)
|
|
param_count = sum(p.numel() for p in model.parameters())
|
|
print(f"\nModel parameters: {param_count:,}", flush=True)
|
|
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
|
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
|
optimizer, mode="max", factor=0.5, patience=3
|
|
)
|
|
|
|
# Training loop
|
|
best_val_acc = 0.0
|
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
best_model_path = OUTPUT_DIR / "accidental_classifier.pt"
|
|
|
|
print(f"\n{'Epoch':>5s} {'TrLoss':>7s} {'VaLoss':>7s} {'VaAcc':>6s} "
|
|
f"{'Flat':>5s} {'Nat':>5s} {'Sharp':>5s} {'LR':>8s} {'Time':>5s}", flush=True)
|
|
print("-" * 70, flush=True)
|
|
|
|
for epoch in range(1, NUM_EPOCHS + 1):
|
|
t0 = time.time()
|
|
model.train()
|
|
train_loss = 0.0
|
|
n_batches = 0
|
|
|
|
for imgs, lbls in train_loader:
|
|
imgs, lbls = imgs.to(device), lbls.to(device)
|
|
optimizer.zero_grad()
|
|
outputs = model(imgs)
|
|
loss = F.cross_entropy(outputs, lbls)
|
|
loss.backward()
|
|
optimizer.step()
|
|
train_loss += loss.item()
|
|
n_batches += 1
|
|
|
|
avg_train_loss = train_loss / n_batches
|
|
val_loss, val_acc, class_acc = evaluate(model, val_loader, device)
|
|
scheduler.step(val_acc)
|
|
|
|
elapsed = time.time() - t0
|
|
lr = optimizer.param_groups[0]["lr"]
|
|
|
|
print(
|
|
f"{epoch:5d} {avg_train_loss:7.4f} {val_loss:7.4f} {val_acc:6.1%} "
|
|
f"{class_acc[0]:5.1%} {class_acc[1]:5.1%} {class_acc[2]:5.1%} "
|
|
f"{lr:.6f} {elapsed:5.1f}s",
|
|
flush=True,
|
|
)
|
|
|
|
if val_acc > best_val_acc:
|
|
best_val_acc = val_acc
|
|
torch.save(model.state_dict(), best_model_path)
|
|
|
|
# Final evaluation on test set
|
|
print(f"\n=== Test Set Evaluation ===", flush=True)
|
|
model.load_state_dict(torch.load(best_model_path, weights_only=True))
|
|
test_loss, test_acc, test_class_acc = evaluate(model, test_loader, device)
|
|
print(f"Test accuracy: {test_acc:.1%}", flush=True)
|
|
for i, name in enumerate(LABELS):
|
|
print(f" {name:8s}: {test_class_acc[i]:.1%}", flush=True)
|
|
|
|
# Export to ONNX for OpenCV dnn inference
|
|
print(f"\nExporting to ONNX...", flush=True)
|
|
model.eval()
|
|
dummy = torch.randn(1, 1, *IMG_SIZE)
|
|
onnx_path = OUTPUT_DIR / "accidental_classifier.onnx"
|
|
torch.onnx.export(
|
|
model,
|
|
dummy,
|
|
str(onnx_path),
|
|
input_names=["image"],
|
|
output_names=["logits"],
|
|
dynamic_axes={"image": {0: "batch"}, "logits": {0: "batch"}},
|
|
opset_version=13,
|
|
)
|
|
print(f"Saved ONNX model: {onnx_path}", flush=True)
|
|
print(f"Saved PyTorch model: {best_model_path}", flush=True)
|
|
|
|
# Quick ONNX verification via OpenCV
|
|
import cv2
|
|
net = cv2.dnn.readNetFromONNX(str(onnx_path))
|
|
test_img = np.random.randint(0, 255, (IMG_SIZE[1], IMG_SIZE[0]), dtype=np.uint8)
|
|
blob = cv2.dnn.blobFromImage(test_img, 1.0 / 255.0, IMG_SIZE, 0, swapRB=False)
|
|
net.setInput(blob)
|
|
out = net.forward()
|
|
pred = LABELS[np.argmax(out[0])]
|
|
print(f"ONNX+OpenCV smoke test: input=random noise -> pred={pred} (expected: arbitrary)", flush=True)
|
|
print(f"\nDone. Labels: {LABELS}", flush=True)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|