def collect_file_paths():
    """Return crop file paths in the order ``load_all_into_ram`` visits them.

    For every label in ``LABELS`` the ``.png`` entries of the label's
    directory are listed, shuffled with the global ``random`` module and
    truncated to ``MAX_PER_CLASS`` entries.  Exactly one ``random.shuffle``
    call is made per label, so the caller must seed ``random`` beforehand
    for the resulting order to line up with train.py.

    Returns:
        A flat list of path strings, grouped by label in ``LABELS`` order.
    """
    collected = []
    for class_name in LABELS:
        class_dir = CROPS_ROOT / class_name
        png_names = [
            entry for entry in os.listdir(class_dir) if entry.endswith(".png")
        ]
        # One shuffle per label keeps RNG consumption in step with train.py.
        random.shuffle(png_names)
        if MAX_PER_CLASS and len(png_names) > MAX_PER_CLASS:
            png_names = png_names[:MAX_PER_CLASS]
        collected.extend(str(class_dir / png_name) for png_name in png_names)
    return collected
def parse_agnostic(path: str) -> list[str]:
    """Parse an agnostic encoding file into a list of tokens.

    The file is read as UTF-8, surrounding whitespace is stripped and the
    content is split on tab characters.

    Args:
        path: Path of the ``.agnostic`` file to read.

    Returns:
        The non-empty tab-separated tokens in file order.  An empty (or
        whitespace-only) file yields an empty list.
    """
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    # Drop empty fields: "".split("\t") returns [""] and doubled tabs produce
    # "" entries, which would otherwise be counted as a real symbol.
    return [tok for tok in content.strip().split("\t") if tok]
"*.agnostic"), + os.path.join(DATASET_ROOT, "package_ab", "*", "*.agnostic"), + ] + agnostic_files = [] + for p in patterns: + agnostic_files.extend(glob.glob(p)) + + print(f"Total incipits: {len(agnostic_files)}") + + # Count accidental tokens + accidental_type_counts = Counter() # sharp/flat/natural + accidental_full_counts = Counter() # full token like accidental.sharp-L5 + incipits_with_accidentals = 0 + incipits_with_inline_accidentals = 0 # accidentals that aren't in key sig + all_symbol_types = Counter() + total_accidentals = 0 + + for path in agnostic_files: + tokens = parse_agnostic(path) + has_any_accidental = False + has_inline = False + past_time_sig = False + + for tok in tokens: + # Track symbol types (just the prefix) + base = tok.split("-")[0] if "-" in tok else tok + all_symbol_types[base] += 1 + + if tok.startswith("digit."): + past_time_sig = True + + if tok.startswith("accidental."): + has_any_accidental = True + total_accidentals += 1 + # Extract type: accidental.sharp, accidental.flat, etc. + acc_type = tok.split("-")[0] # e.g. 
"accidental.sharp" + accidental_type_counts[acc_type] += 1 + accidental_full_counts[tok] += 1 + + if past_time_sig: + has_inline = True + + if has_any_accidental: + incipits_with_accidentals += 1 + if has_inline: + incipits_with_inline_accidentals += 1 + + print(f"\n=== Accidental Statistics ===") + print(f"Total accidental tokens: {total_accidentals}") + print(f"Incipits with any accidentals: {incipits_with_accidentals} / {len(agnostic_files)} ({100*incipits_with_accidentals/len(agnostic_files):.1f}%)") + print(f"Incipits with inline accidentals: {incipits_with_inline_accidentals} / {len(agnostic_files)} ({100*incipits_with_inline_accidentals/len(agnostic_files):.1f}%)") + + print(f"\n=== Accidental Type Counts ===") + for acc_type, count in accidental_type_counts.most_common(): + print(f" {acc_type:25s} {count:7d}") + + print(f"\n=== Top 20 Accidental Positions ===") + for tok, count in accidental_full_counts.most_common(20): + print(f" {tok:30s} {count:7d}") + + print(f"\n=== Top 30 Symbol Types ===") + for sym, count in all_symbol_types.most_common(30): + print(f" {sym:30s} {count:7d}") + + # Image statistics from a sample + print(f"\n=== Image Statistics (sample of 500) ===") + png_files = glob.glob(os.path.join(DATASET_ROOT, "package_aa", "*", "*.png"))[:500] + widths, heights = [], [] + for f in png_files: + im = Image.open(f) + widths.append(im.size[0]) + heights.append(im.size[1]) + + widths = np.array(widths) + heights = np.array(heights) + print(f" Width: min={widths.min()}, max={widths.max()}, mean={widths.mean():.0f}, std={widths.std():.0f}") + print(f" Height: min={heights.min()}, max={heights.max()}, mean={heights.mean():.0f}, std={heights.std():.0f}") + print(f" Modes: {Counter(Image.open(f).mode for f in png_files[:50])}") + + +if __name__ == "__main__": + main() diff --git a/extract_accidentals.py b/extract_accidentals.py new file mode 100644 index 0000000..8b72f67 --- /dev/null +++ b/extract_accidentals.py @@ -0,0 +1,214 @@ +"""Extract accidental 
def extract_from_mei(tk, mei_path: str) -> list[dict]:
    """Extract accidental crops from a single MEI file.

    The MEI file is loaded into ``tk``, page 1 is rendered to SVG, accidental
    elements are located with ``ACCID_RE``, the SVG is rasterized with
    cairosvg and a fixed-size window around each accidental is cropped.

    Args:
        tk: Object exposing ``loadData`` and ``renderToSVG`` (presumably the
            Verovio toolkit created by the script's setup code).
        mei_path: Path of the MEI file to render.

    Returns:
        List of dicts with 'type', 'image' (PIL Image), 'context'
        (keyAccid/accid).  An empty list is returned when rendering yields
        nothing, the SVG dimensions cannot be read or are zero, no known
        accidental glyph is found, or rasterization fails.
    """
    with open(mei_path, "r", encoding="utf-8") as f:
        mei_data = f.read()

    tk.loadData(mei_data)
    svg = tk.renderToSVG(1)

    if not svg:
        return []

    # Get SVG pixel dimensions
    dim_match = SVG_DIM_RE.search(svg)
    if not dim_match:
        return []
    svg_w = int(dim_match.group(1))
    svg_h = int(dim_match.group(2))

    if svg_w == 0 or svg_h == 0:
        return []

    # Find accidentals in SVG
    accid_positions = []
    for match in ACCID_RE.finditer(svg):
        cls = match.group(1)
        glyph_id = match.group(2)
        tx, ty = int(match.group(3)), int(match.group(4))

        glyph_type = GLYPH_MAP.get(glyph_id)
        if glyph_type is None:
            continue

        # viewBox coords -> pixel coords (viewBox is 10x pixel dimensions)
        px = tx / 10.0
        py = ty / 10.0

        accid_positions.append(
            {"type": glyph_type, "context": cls, "px": px, "py": py}
        )

    if not accid_positions:
        return []

    # Rasterize SVG to PNG
    try:
        png_bytes = cairosvg.svg2png(
            bytestring=svg.encode("utf-8"),
            output_width=svg_w,
            output_height=svg_h,
        )
    except Exception:
        # Best effort: a file that cannot be rasterized contributes no crops.
        return []

    # Flatten onto an opaque white canvas before dropping to grayscale.
    # Converting an RGBA image straight to "L" discards alpha, so transparent
    # background pixels would come out black, the same as the ink.  This
    # mirrors what extract_fast.py does.
    rendered = Image.open(BytesIO(png_bytes)).convert("RGBA")
    white_bg = Image.new("RGBA", rendered.size, (255, 255, 255, 255))
    img = Image.alpha_composite(white_bg, rendered).convert("L")

    # Crop each accidental
    results = []
    for acc in accid_positions:
        # Window around the glyph anchor, clamped to the image bounds.
        x1 = max(0, int(acc["px"] - CROP_PAD_LEFT))
        y1 = max(0, int(acc["py"] - CROP_PAD_TOP))
        x2 = min(svg_w, int(acc["px"] + CROP_PAD_RIGHT))
        y2 = min(svg_h, int(acc["py"] + CROP_PAD_BOTTOM))

        # Skip windows that clamping reduced to a sliver.
        if x2 - x1 < 5 or y2 - y1 < 5:
            continue

        crop = img.crop((x1, y1, x2, y2))
        results.append(
            {
                "type": acc["type"],
                "context": acc["context"],
                "image": crop,
            }
        )

    return results
Process with multiple Verovio fonts for variety + fonts = ["Leipzig", "Bravura", "Gootville"] + counts = {"sharp": 0, "flat": 0, "natural": 0} + errors = 0 + total_processed = 0 + + for font_idx, font in enumerate(fonts): + tk.setOptions({"font": font}) + print(f"\n=== Font: {font} (pass {font_idx+1}/{len(fonts)}) ===") + + for i, mei_path in enumerate(mei_files): + if i % 5000 == 0: + print( + f" [{font}] {i}/{len(mei_files)} " + f"sharp={counts['sharp']} flat={counts['flat']} " + f"natural={counts['natural']} errors={errors}" + ) + + try: + results = extract_from_mei(tk, mei_path) + except Exception: + errors += 1 + if errors <= 5: + traceback.print_exc() + continue + + sample_id = os.path.basename(os.path.dirname(mei_path)) + for j, res in enumerate(results): + label = res["type"] + fname = f"{sample_id}_f{font_idx}_{j}.png" + out_path = os.path.join(OUTPUT_ROOT, label, fname) + res["image"].save(out_path) + counts[label] += 1 + + total_processed += 1 + + print(f"\n=== Done ===") + print(f"Processed: {total_processed} MEI files across {len(fonts)} fonts") + print(f"Errors: {errors}") + print(f"Extracted crops:") + for label, count in sorted(counts.items()): + print(f" {label:10s}: {count:7d}") + print(f"Output: {OUTPUT_ROOT}") + + +if __name__ == "__main__": + main() diff --git a/extract_fast.py b/extract_fast.py new file mode 100644 index 0000000..2c805ca --- /dev/null +++ b/extract_fast.py @@ -0,0 +1,161 @@ +"""Fast accidental extraction from PrIMuS using Verovio + cairosvg. 
def main():
    """Render every MEI file once and save a crop for each accidental found.

    All ``*.mei`` files under ``DATASET_ROOT/package_*/*/`` are rendered to
    SVG, accidentals are located with ``ACCID_RE``, the SVG is rasterized
    with cairosvg, and one PNG crop per accidental is written to
    ``OUTPUT_ROOT/<type>/``.  Progress is printed every 1000 files and a
    summary at the end.  Any exception while handling a file is counted and
    the file is skipped; only the first five tracebacks are printed.
    """
    # verovio is imported inside the function rather than at module scope.
    import verovio

    # Glob all MEI files at once (much faster than iterating dirs on Windows)
    print("Finding MEI files...", flush=True)
    mei_files = sorted(
        glob.glob(os.path.join(DATASET_ROOT, "package_*", "*", "*.mei"))
    )
    print(f"Found {len(mei_files)} MEI files (skipping pre-filter)", flush=True)

    # Create output directories
    for label in GLYPH_MAP.values():
        os.makedirs(os.path.join(OUTPUT_ROOT, label), exist_ok=True)

    # Setup Verovio
    tk = verovio.toolkit()
    tk.setOptions({
        "adjustPageHeight": True,
        "adjustPageWidth": True,
        "header": "none",
        "footer": "none",
        "breaks": "none",
        "scale": 100,
        "pageMarginLeft": 0,
        "pageMarginRight": 0,
        "pageMarginTop": 0,
        "pageMarginBottom": 0,
        "font": "Leipzig",
    })

    # Number of crops written per accidental type, and failed-file count.
    counts = {"sharp": 0, "flat": 0, "natural": 0}
    errors = 0
    t0 = time.time()

    for i, mei_path in enumerate(mei_files):
        # Progress line every 1000 files; rate and ETA fall back to 0 until
        # there is something to divide by.
        if i % 1000 == 0:
            elapsed = time.time() - t0
            rate = i / elapsed if elapsed > 0 else 0
            eta = (len(mei_files) - i) / rate if rate > 0 else 0
            print(
                f"[{i:6d}/{len(mei_files)}] "
                f"sharp={counts['sharp']} flat={counts['flat']} "
                f"natural={counts['natural']} "
                f"err={errors} "
                f"rate={rate:.0f}/s eta={eta/60:.0f}m",
                flush=True,
            )

        # The whole per-file pipeline sits in one try so a single bad file
        # cannot abort the run.
        try:
            with open(mei_path, "r", encoding="utf-8") as f:
                mei_data = f.read()

            tk.loadData(mei_data)
            svg = tk.renderToSVG(1)
            if not svg:
                continue

            # Dimensions parsed from the SVG are reused as the raster size.
            dim_match = SVG_DIM_RE.search(svg)
            if not dim_match:
                continue
            svg_w = int(dim_match.group(1))
            svg_h = int(dim_match.group(2))
            if svg_w == 0 or svg_h == 0:
                continue

            # Find accidentals
            accids = []
            for match in ACCID_RE.finditer(svg):
                glyph_type = GLYPH_MAP.get(match.group(2))
                if glyph_type:
                    # Matched coordinates are divided by 10 to get pixel
                    # positions (extract_accidentals.py notes the viewBox is
                    # 10x the pixel dimensions).
                    accids.append({
                        "type": glyph_type,
                        "px": int(match.group(3)) / 10.0,
                        "py": int(match.group(4)) / 10.0,
                    })

            if not accids:
                continue

            # Rasterize
            png_bytes = cairosvg.svg2png(
                bytestring=svg.encode("utf-8"),
                output_width=svg_w,
                output_height=svg_h,
            )
            # Composite over opaque white before grayscale conversion
            # (presumably because the rasterized PNG has a transparent
            # background -- TODO confirm).
            img_rgba = Image.open(BytesIO(png_bytes))
            white_bg = Image.new("RGBA", img_rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(white_bg, img_rgba).convert("L")

            # The parent directory name of the MEI file is the sample id.
            sample_id = os.path.basename(os.path.dirname(mei_path))
            for j, acc in enumerate(accids):
                # Fixed window around the anchor: 3 px left, 22 px right,
                # 25 px above, 35 px below, clamped to the image bounds.
                x1 = max(0, int(acc["px"] - 3))
                y1 = max(0, int(acc["py"] - 25))
                x2 = min(svg_w, int(acc["px"] + 22))
                y2 = min(svg_h, int(acc["py"] + 35))
                # Skip windows that clamping reduced below 5 px on a side.
                if x2 - x1 < 5 or y2 - y1 < 5:
                    continue

                crop = img.crop((x1, y1, x2, y2))
                fname = f"{sample_id}_{j}.png"
                crop.save(os.path.join(OUTPUT_ROOT, acc["type"], fname))
                counts[acc["type"]] += 1

        except Exception:
            errors += 1
            # Only the first five failures print a traceback.
            if errors <= 5:
                traceback.print_exc()

    elapsed = time.time() - t0
    print(f"\n=== Done in {elapsed/60:.1f} minutes ===", flush=True)
    print(f"Errors: {errors}")
    print(f"Crops extracted:")
    for label, count in sorted(counts.items()):
        print(f" {label:10s}: {count:7d}")
    print(f"Total: {sum(counts.values())}")
    print(f"Output: {OUTPUT_ROOT}")
0000000..37a211d Binary files /dev/null and b/model/accidental_classifier.pt differ diff --git a/segment_test.py b/segment_test.py new file mode 100644 index 0000000..7c29067 --- /dev/null +++ b/segment_test.py @@ -0,0 +1,122 @@ +"""Test segmenting symbols from a PrIMuS image using connected components.""" + +import cv2 +import numpy as np +from PIL import Image + + +def segment_symbols(png_path: str): + """Segment a PrIMuS image into individual symbols.""" + # Load as grayscale + img = cv2.imread(png_path, cv2.IMREAD_GRAYSCALE) + if img is None: + # PrIMuS uses palette mode PNGs, load via PIL first + pil_img = Image.open(png_path).convert("L") + img = np.array(pil_img) + + h, w = img.shape + print(f"Image size: {w} x {h}") + + # Binarize (black ink on white background) + _, binary = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY_INV) + + # Remove staff lines using horizontal morphology + # Detect staff lines + horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (w // 4, 1)) + staff_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horizontal_kernel) + + # Remove staff lines from binary image + no_lines = binary - staff_lines + + # Clean up with small morphological close to reconnect broken symbols + close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2)) + no_lines = cv2.morphologyEx(no_lines, cv2.MORPH_CLOSE, close_kernel) + + # Find connected components + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( + no_lines, connectivity=8 + ) + + # Filter out tiny noise (area < 10 pixels) + symbols = [] + for i in range(1, num_labels): # skip background + x, y, sw, sh, area = stats[i] + if area < 10: + continue + symbols.append({ + "x": x, "y": y, "w": sw, "h": sh, + "area": area, "cx": centroids[i][0], + }) + + # Sort by x position (left to right) + symbols.sort(key=lambda s: s["x"]) + + # Group nearby components into symbol clusters + # (a single symbol like a sharp may have multiple disconnected parts) + clusters = [] + for s 
def match_tokens_to_clusters(agnostic_path: str, clusters: list):
    """Print the agnostic tokens next to the cluster count and return them.

    No actual token-to-cluster matching is performed yet: the function reads
    the tokens, prints each one with its index, then prints how many clusters
    and tokens there are so the two counts can be compared by eye.

    Args:
        agnostic_path: Path of the tab-separated ``.agnostic`` file.
        clusters: Segmented symbol clusters; only ``len(clusters)`` is used.

    Returns:
        The list of non-empty tokens read from the file.
    """
    # Explicit UTF-8 so the result does not depend on the platform default
    # encoding (the other agnostic readers in this project also use UTF-8).
    with open(agnostic_path, "r", encoding="utf-8") as f:
        content = f.read()
    # Drop empty fields so an empty file reports 0 tokens instead of 1.
    tokens = [tok for tok in content.strip().split("\t") if tok]

    print(f"\nAgnostic tokens ({len(tokens)}):")
    for i, tok in enumerate(tokens):
        print(f" {i:3d}: {tok}")

    print(f"\nClusters: {len(clusters)}")
    print(f"Tokens: {len(tokens)}")

    return tokens
class InMemoryDataset(Dataset):
    """Tensor-backed dataset whose samples all live in RAM.

    ``images`` is indexed along its first dimension and paired with the
    entry of ``labels`` at the same index.  When ``augment`` is true each
    fetched image goes through a random affine transform followed by random
    erasing; otherwise it is returned untouched.
    """

    def __init__(self, images: torch.Tensor, labels: torch.Tensor, augment=False):
        self.images = images  # (N, 1, H, W) float32
        self.labels = labels  # (N,) long
        self.augment = augment
        # Both transforms are built up front and only applied when augmenting.
        self.affine = transforms.RandomAffine(
            degrees=5,
            translate=(0.1, 0.1),
            scale=(0.9, 1.1),
        )
        self.erasing = transforms.RandomErasing(p=0.2, scale=(0.02, 0.08))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.images[idx]
        target = self.labels[idx]
        if not self.augment:
            return sample, target
        # Affine first, then erasing.
        return self.erasing(self.affine(sample)), target
def load_all_into_ram():
    """Read every selected crop from disk and return it as tensors.

    For each label in ``LABELS`` the ``.png`` entries of
    ``CROPS_ROOT/<label>`` are listed, shuffled with the global ``random``
    module and truncated to ``MAX_PER_CLASS``.  Each image is converted to
    grayscale, resized to ``IMG_SIZE`` with bilinear resampling and scaled
    to the 0..1 range.

    Returns:
        ``(images, labels)``: a float tensor with a channel axis inserted at
        dimension 1, and a long tensor of label indices into ``LABELS``.
    """
    print("Loading all images into RAM...", flush=True)
    arrays, targets = [], []

    for class_idx, class_name in enumerate(LABELS):
        class_dir = CROPS_ROOT / class_name
        png_names = [
            entry for entry in os.listdir(class_dir) if entry.endswith(".png")
        ]
        # One shuffle per label; eval.py replays this to recover file paths.
        random.shuffle(png_names)
        if MAX_PER_CLASS and len(png_names) > MAX_PER_CLASS:
            png_names = png_names[:MAX_PER_CLASS]
        print(f" {class_name}: {len(png_names)} files", flush=True)

        for png_name in png_names:
            gray = Image.open(class_dir / png_name).convert("L")
            resized = gray.resize(IMG_SIZE, Image.BILINEAR)
            arrays.append(np.array(resized, dtype=np.float32) / 255.0)
            targets.append(class_idx)

    # Stack into tensors: (N, 1, H, W)
    stacked = np.array(arrays)
    images = torch.tensor(stacked).unsqueeze(1)
    labels = torch.tensor(targets, dtype=torch.long)
    print(f" Total: {len(labels)} images, tensor shape: {images.shape}", flush=True)
    print(f" Memory: {images.nbytes / 1024**2:.0f} MB", flush=True)
    return images, labels
def evaluate(model, loader, device, num_classes=None):
    """Evaluate model on a data loader.

    Args:
        model: Module mapping an image batch to per-class logits.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        num_classes: Number of classes to report per-class accuracy for.
            Defaults to ``len(LABELS)``.

    Returns:
        ``(avg_loss, acc, class_acc)``: the mean cross-entropy per sample,
        the overall accuracy, and a float array of per-class accuracies
        (0 for classes with no samples).  An empty loader yields zeros.
    """
    if num_classes is None:
        num_classes = len(LABELS)

    model.eval()
    correct = 0
    total = 0
    class_correct = np.zeros(num_classes)
    class_total = np.zeros(num_classes)
    loss_sum = 0.0

    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            outputs = model(imgs)
            # Sum per-sample losses and divide by the sample count at the
            # end; averaging per-batch means would over-weight a short
            # final batch.
            loss_sum += F.cross_entropy(outputs, lbls, reduction="sum").item()

            _, predicted = outputs.max(1)
            total += lbls.size(0)
            correct += predicted.eq(lbls).sum().item()

            for i in range(num_classes):
                mask = lbls == i
                class_total[i] += mask.sum().item()
                class_correct[i] += (predicted[mask] == i).sum().item()

    acc = correct / total if total > 0 else 0
    # Divide only where a class has samples.  np.where would evaluate the
    # full division first and emit divide-by-zero warnings for empty classes.
    class_acc = np.divide(
        class_correct,
        class_total,
        out=np.zeros_like(class_correct),
        where=class_total > 0,
    )
    avg_loss = loss_sum / total if total > 0 else 0
    return avg_loss, acc, class_acc
RAM + images, labels = load_all_into_ram() + + # Split + train_idx, val_idx, test_idx = split_data(images, labels) + print(f"\nTrain: {len(train_idx)} Val: {len(val_idx)} Test: {len(test_idx)}", flush=True) + for i, name in enumerate(LABELS): + nt = (labels[train_idx] == i).sum().item() + nv = (labels[val_idx] == i).sum().item() + ne = (labels[test_idx] == i).sum().item() + print(f" {name:8s}: train={nt:6d} val={nv:5d} test={ne:5d}", flush=True) + + # Create datasets (in-memory, fast) + train_ds = InMemoryDataset(images[train_idx], labels[train_idx], augment=True) + val_ds = InMemoryDataset(images[val_idx], labels[val_idx], augment=False) + test_ds = InMemoryDataset(images[test_idx], labels[test_idx], augment=False) + + # Weighted sampler for class imbalance + sampler = make_weighted_sampler(labels[train_idx]) + + train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0) + val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0) + test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0) + + # Model + model = AccidentalCNN().to(device) + param_count = sum(p.numel() for p in model.parameters()) + print(f"\nModel parameters: {param_count:,}", flush=True) + + optimizer = torch.optim.Adam(model.parameters(), lr=LR) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="max", factor=0.5, patience=3 + ) + + # Training loop + best_val_acc = 0.0 + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + best_model_path = OUTPUT_DIR / "accidental_classifier.pt" + + print(f"\n{'Epoch':>5s} {'TrLoss':>7s} {'VaLoss':>7s} {'VaAcc':>6s} " + f"{'Flat':>5s} {'Nat':>5s} {'Sharp':>5s} {'LR':>8s} {'Time':>5s}", flush=True) + print("-" * 70, flush=True) + + for epoch in range(1, NUM_EPOCHS + 1): + t0 = time.time() + model.train() + train_loss = 0.0 + n_batches = 0 + + for imgs, lbls in train_loader: + imgs, lbls = imgs.to(device), lbls.to(device) + optimizer.zero_grad() + outputs = 
model(imgs) + loss = F.cross_entropy(outputs, lbls) + loss.backward() + optimizer.step() + train_loss += loss.item() + n_batches += 1 + + avg_train_loss = train_loss / n_batches + val_loss, val_acc, class_acc = evaluate(model, val_loader, device) + scheduler.step(val_acc) + + elapsed = time.time() - t0 + lr = optimizer.param_groups[0]["lr"] + + print( + f"{epoch:5d} {avg_train_loss:7.4f} {val_loss:7.4f} {val_acc:6.1%} " + f"{class_acc[0]:5.1%} {class_acc[1]:5.1%} {class_acc[2]:5.1%} " + f"{lr:.6f} {elapsed:5.1f}s", + flush=True, + ) + + if val_acc > best_val_acc: + best_val_acc = val_acc + torch.save(model.state_dict(), best_model_path) + + # Final evaluation on test set + print(f"\n=== Test Set Evaluation ===", flush=True) + model.load_state_dict(torch.load(best_model_path, weights_only=True)) + test_loss, test_acc, test_class_acc = evaluate(model, test_loader, device) + print(f"Test accuracy: {test_acc:.1%}", flush=True) + for i, name in enumerate(LABELS): + print(f" {name:8s}: {test_class_acc[i]:.1%}", flush=True) + + # Export to ONNX for OpenCV dnn inference + print(f"\nExporting to ONNX...", flush=True) + model.eval() + dummy = torch.randn(1, 1, *IMG_SIZE) + onnx_path = OUTPUT_DIR / "accidental_classifier.onnx" + torch.onnx.export( + model, + dummy, + str(onnx_path), + input_names=["image"], + output_names=["logits"], + dynamic_axes={"image": {0: "batch"}, "logits": {0: "batch"}}, + opset_version=13, + ) + print(f"Saved ONNX model: {onnx_path}", flush=True) + print(f"Saved PyTorch model: {best_model_path}", flush=True) + + # Quick ONNX verification via OpenCV + import cv2 + net = cv2.dnn.readNetFromONNX(str(onnx_path)) + test_img = np.random.randint(0, 255, (IMG_SIZE[1], IMG_SIZE[0]), dtype=np.uint8) + blob = cv2.dnn.blobFromImage(test_img, 1.0 / 255.0, IMG_SIZE, 0, swapRB=False) + net.setInput(blob) + out = net.forward() + pred = LABELS[np.argmax(out[0])] + print(f"ONNX+OpenCV smoke test: input=random noise -> pred={pred} (expected: arbitrary)", flush=True) + 
print(f"\nDone. Labels: {LABELS}", flush=True) + + +if __name__ == "__main__": + main()