"""Train a small CNN to classify accidentals (sharp/flat/natural).

Outputs an ONNX model for inference via OpenCV's dnn module.
"""

import os
import random
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

# PIL, torchvision and cv2 are imported inside the functions that need them, so
# the model definition and evaluation helpers can be imported (e.g. for
# inference or tests) without the data-pipeline / verification dependencies.

CROPS_ROOT = Path(r"C:\src\accidentals\crops")
OUTPUT_DIR = Path(r"C:\src\accidentals\model")

LABELS = ["flat", "natural", "sharp"]  # alphabetical → class indices 0,1,2
IMG_SIZE = (40, 40)  # (width, height) as PIL expects; small enough for fast inference
BATCH_SIZE = 256
NUM_EPOCHS = 30
LR = 1e-3
SEED = 42
VAL_SPLIT = 0.1
TEST_SPLIT = 0.05
MAX_PER_CLASS = 5000  # subsample: more than enough for this easy problem


class InMemoryDataset(Dataset):
    """Dataset that holds all images as tensors in RAM.

    Args:
        images: float32 tensor of shape (N, 1, H, W), as produced by
            ``load_all_into_ram``.
        labels: int64 tensor of shape (N,) with class indices into LABELS.
        augment: if True, every access applies a random affine warp followed
            by random erasing.
    """

    def __init__(self, images: torch.Tensor, labels: torch.Tensor, augment=False):
        from torchvision import transforms

        self.images = images  # (N, 1, H, W) float32
        self.labels = labels  # (N,) long
        self.augment = augment
        # NOTE(review): RandomAffine fills exposed borders with 0 and
        # RandomErasing writes 0 by default (black after the /255 scaling).
        # If the crops are dark ink on a light background this adds dark
        # artefacts; consider fill=1.0 / value=1.0 -- confirm against the data.
        self.affine = transforms.RandomAffine(
            degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)
        )
        self.erasing = transforms.RandomErasing(p=0.2, scale=(0.02, 0.08))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.augment:
            img = self.affine(img)
            img = self.erasing(img)
        return img, self.labels[idx]


class AccidentalCNN(nn.Module):
    """Small CNN for 3-class accidental classification.

    Input: (N, 1, H, W) grayscale batch with (W, H) = IMG_SIZE (40x40 by default).
    Architecture: 3 conv blocks (each halves H and W) + a 2-layer FC head.
    Output: (N, len(LABELS)) raw logits.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1: 1 -> 16 channels, 40x40 -> 20x20
            nn.Conv2d(1, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Block 2: 16 -> 32 channels, 20x20 -> 10x10
            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Block 3: 32 -> 64 channels, 10x10 -> 5x5
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        # Three 2x2 max-pools floor-divide each spatial dim by 8 (40 -> 5), so
        # the head adapts to IMG_SIZE instead of hard-coding 64 * 5 * 5.
        width, height = IMG_SIZE
        flat_dim = 64 * (height // 8) * (width // 8)
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(flat_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(64, len(LABELS)),
        )

    def forward(self, x):
        x = self.features(x)
        # torch.flatten(x, 1) exports as a single ONNX Flatten op rather than
        # the Shape/Gather/Concat/Reshape chain traced from view(x.size(0), -1).
        x = torch.flatten(x, 1)
        return self.classifier(x)


def load_all_into_ram():
    """Load the (optionally subsampled) crops into RAM as tensors.

    For each label in LABELS, reads the ``*.png`` files in CROPS_ROOT/<label>,
    keeps at most MAX_PER_CLASS of them (random, seeded via ``random``),
    converts to grayscale, resizes to IMG_SIZE and scales pixels to [0, 1].

    Returns:
        (images, labels): float32 tensor (N, 1, H, W) and int64 tensor (N,).

    Raises:
        FileNotFoundError: if a label directory is missing or no .png crops
            were found at all.
    """
    from PIL import Image

    print("Loading all images into RAM...", flush=True)
    all_images = []
    all_labels = []

    for label_idx, label_name in enumerate(LABELS):
        label_dir = CROPS_ROOT / label_name
        # Sort first: os.listdir order is filesystem-dependent, so without it
        # the seeded shuffle/subsample is not reproducible across machines.
        files = sorted(f for f in os.listdir(label_dir) if f.endswith(".png"))
        random.shuffle(files)
        if MAX_PER_CLASS and len(files) > MAX_PER_CLASS:
            files = files[:MAX_PER_CLASS]
        print(f"  {label_name}: {len(files)} files", flush=True)

        for fname in files:
            # The context manager closes the file handle deterministically
            # instead of leaving it to the garbage collector.
            with Image.open(label_dir / fname) as img:
                gray = img.convert("L").resize(IMG_SIZE, Image.BILINEAR)
            all_images.append(np.asarray(gray, dtype=np.float32) / 255.0)
            all_labels.append(label_idx)

    if not all_images:
        raise FileNotFoundError(f"No .png crops found under {CROPS_ROOT}")

    # Stack into tensors: (N, 1, H, W)
    images = torch.from_numpy(np.stack(all_images)).unsqueeze(1)
    labels = torch.tensor(all_labels, dtype=torch.long)
    print(f"  Total: {len(labels)} images, tensor shape: {images.shape}", flush=True)
    print(f"  Memory: {images.nbytes / 1024**2:.0f} MB", flush=True)
    return images, labels


def split_data(images, labels):
    """Stratified train/val/test split.

    ``images`` is not used; it is kept for call-site compatibility. Samples are
    permuted with torch's global RNG, then each class contributes
    int(n * TEST_SPLIT) samples to test and int(n * VAL_SPLIT) to validation
    (each at least 1, if the class has that many), and the rest to train.

    Returns:
        (train_idx, val_idx, test_idx) as int64 index tensors into ``labels``.
    """
    perm = torch.randperm(len(labels))
    permuted_labels = labels[perm]

    train_idx, val_idx, test_idx = [], [], []
    for cls in range(len(LABELS)):
        # This class's sample indices, in permutation order.
        idxs = perm[permuted_labels == cls].tolist()
        n = len(idxs)
        n_test = max(1, int(n * TEST_SPLIT))
        n_val = max(1, int(n * VAL_SPLIT))
        test_idx.extend(idxs[:n_test])
        val_idx.extend(idxs[n_test : n_test + n_val])
        train_idx.extend(idxs[n_test + n_val :])

    # Explicit dtype: torch.tensor([]) would otherwise be float32 and break
    # indexing when a split ends up empty.
    return (
        torch.tensor(train_idx, dtype=torch.long),
        torch.tensor(val_idx, dtype=torch.long),
        torch.tensor(test_idx, dtype=torch.long),
    )


def make_weighted_sampler(labels):
    """Create a sampler that oversamples minority classes.

    Each sample is weighted by 1 / (count of its class), so every class is
    drawn with equal expected frequency; ``len(labels)`` draws per epoch, with
    replacement.
    """
    class_counts = torch.bincount(labels, minlength=len(LABELS)).float()
    weights = 1.0 / class_counts[labels]
    return WeightedRandomSampler(weights, len(labels), replacement=True)


def evaluate(model, loader, device):
    """Evaluate ``model`` on every batch of ``loader`` without gradients.

    Returns:
        (avg_loss, acc, class_acc): per-sample mean cross-entropy, overall
        accuracy, and a numpy array of per-class accuracy indexed like LABELS
        (0 for classes that do not occur in ``loader``). Loss and accuracy are
        0 for an empty loader.
    """
    model.eval()
    num_classes = len(LABELS)
    class_correct = np.zeros(num_classes)
    class_total = np.zeros(num_classes)
    total_loss = 0.0

    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            outputs = model(imgs)
            # Sum per batch and divide by the sample count at the end; a
            # mean of per-batch means over-weights a short final batch.
            total_loss += F.cross_entropy(outputs, lbls, reduction="sum").item()
            predicted = outputs.argmax(1)
            hits = lbls[predicted == lbls]
            class_total += torch.bincount(lbls, minlength=num_classes).cpu().numpy()
            class_correct += torch.bincount(hits, minlength=num_classes).cpu().numpy()

    total = class_total.sum()
    if total == 0:
        return 0.0, 0.0, np.zeros(num_classes)

    acc = float(class_correct.sum() / total)
    class_acc = np.divide(
        class_correct, class_total, out=np.zeros(num_classes), where=class_total > 0
    )
    return total_loss / total, acc, class_acc


def _train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    model.train()
    loss_sum = 0.0
    n_samples = 0
    for imgs, lbls in loader:
        imgs, lbls = imgs.to(device), lbls.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(imgs), lbls)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the reported value matches evaluate().
        loss_sum += loss.item() * lbls.size(0)
        n_samples += lbls.size(0)
    return loss_sum / n_samples if n_samples else 0.0


def _export_onnx(model, onnx_path):
    """Export ``model`` to ONNX with a dynamic batch dimension."""
    model.eval()
    width, height = IMG_SIZE
    # Tensors are (N, C, H, W) while IMG_SIZE is PIL's (width, height).
    dummy = torch.randn(1, 1, height, width)
    torch.onnx.export(
        model,
        dummy,
        str(onnx_path),
        input_names=["image"],
        output_names=["logits"],
        dynamic_axes={"image": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=13,
    )


def _verify_onnx_with_opencv(model, onnx_path):
    """Run the exported model through OpenCV dnn and compare with PyTorch.

    Skipped with a message if OpenCV is not installed: both model files have
    already been written by the time this runs, so there is no reason to crash.
    """
    try:
        import cv2
    except ImportError:
        print("OpenCV (cv2) not installed; skipping ONNX+OpenCV smoke test.", flush=True)
        return

    net = cv2.dnn.readNetFromONNX(str(onnx_path))
    test_img = np.random.randint(0, 255, (IMG_SIZE[1], IMG_SIZE[0]), dtype=np.uint8)
    blob = cv2.dnn.blobFromImage(test_img, 1.0 / 255.0, IMG_SIZE, 0, swapRB=False)
    net.setInput(blob)
    out = net.forward()
    pred = LABELS[int(np.argmax(out[0]))]
    print(
        f"ONNX+OpenCV smoke test: input=random noise -> pred={pred} (expected: arbitrary)",
        flush=True,
    )

    # Feed the identical blob through PyTorch: logits should agree up to
    # float rounding if the export is faithful.
    model.eval()
    with torch.no_grad():
        ref = model(torch.from_numpy(np.ascontiguousarray(blob, dtype=np.float32))).numpy()
    max_diff = float(np.max(np.abs(out - ref)))
    print(f"ONNX+OpenCV vs PyTorch max |logit diff|: {max_diff:.2e}", flush=True)
    if max_diff > 1e-3:
        print("WARNING: OpenCV output diverges from PyTorch; check the ONNX export.", flush=True)


def main():
    """Train, keep the best-validation checkpoint, test it, export to ONNX."""
    # Done here rather than at import time so importing this module does not
    # touch the caller's stdout; guarded because stdout may be None or a
    # stream without reconfigure().
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding="utf-8")

    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    device = torch.device("cpu")
    print(f"Device: {device}", flush=True)

    # Load everything into RAM
    images, labels = load_all_into_ram()

    # Split
    train_idx, val_idx, test_idx = split_data(images, labels)
    print(f"\nTrain: {len(train_idx)} Val: {len(val_idx)} Test: {len(test_idx)}", flush=True)
    for i, name in enumerate(LABELS):
        nt = (labels[train_idx] == i).sum().item()
        nv = (labels[val_idx] == i).sum().item()
        ne = (labels[test_idx] == i).sum().item()
        print(f"  {name:8s}: train={nt:6d} val={nv:5d} test={ne:5d}", flush=True)

    # Create datasets (in-memory, fast)
    train_ds = InMemoryDataset(images[train_idx], labels[train_idx], augment=True)
    val_ds = InMemoryDataset(images[val_idx], labels[val_idx], augment=False)
    test_ds = InMemoryDataset(images[test_idx], labels[test_idx], augment=False)

    # Weighted sampler for class imbalance
    sampler = make_weighted_sampler(labels[train_idx])
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Model
    model = AccidentalCNN().to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {param_count:,}", flush=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=3
    )

    # Training loop. Start below any achievable accuracy so the first epoch
    # always writes a checkpoint and the torch.load below cannot miss the file.
    best_val_acc = -1.0
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    best_model_path = OUTPUT_DIR / "accidental_classifier.pt"

    print(
        f"\n{'Epoch':>5s} {'TrLoss':>7s} {'VaLoss':>7s} {'VaAcc':>6s} "
        f"{'Flat':>5s} {'Nat':>5s} {'Sharp':>5s} {'LR':>8s} {'Time':>5s}",
        flush=True,
    )
    print("-" * 70, flush=True)

    for epoch in range(1, NUM_EPOCHS + 1):
        t0 = time.perf_counter()
        avg_train_loss = _train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_acc, class_acc = evaluate(model, val_loader, device)
        scheduler.step(val_acc)
        elapsed = time.perf_counter() - t0

        lr = optimizer.param_groups[0]["lr"]
        print(
            f"{epoch:5d} {avg_train_loss:7.4f} {val_loss:7.4f} {val_acc:6.1%} "
            f"{class_acc[0]:5.1%} {class_acc[1]:5.1%} {class_acc[2]:5.1%} "
            f"{lr:.6f} {elapsed:5.1f}s",
            flush=True,
        )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)

    # Final evaluation on test set
    print("\n=== Test Set Evaluation ===", flush=True)
    model.load_state_dict(torch.load(best_model_path, weights_only=True))
    test_loss, test_acc, test_class_acc = evaluate(model, test_loader, device)
    print(f"Test accuracy: {test_acc:.1%}", flush=True)
    for i, name in enumerate(LABELS):
        print(f"  {name:8s}: {test_class_acc[i]:.1%}", flush=True)

    # Export to ONNX for OpenCV dnn inference
    print("\nExporting to ONNX...", flush=True)
    onnx_path = OUTPUT_DIR / "accidental_classifier.onnx"
    _export_onnx(model, onnx_path)
    print(f"Saved ONNX model: {onnx_path}", flush=True)
    print(f"Saved PyTorch model: {best_model_path}", flush=True)

    # Quick ONNX verification via OpenCV
    _verify_onnx_with_opencv(model, onnx_path)

    print(f"\nDone. Labels: {LABELS}", flush=True)


if __name__ == "__main__":
    main()