Compare commits

...

2 commits

Author SHA1 Message Date
dullfig
85c8cfd8bb Add README with architecture, pipeline, and usage docs
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 08:04:03 -08:00
dullfig
afc16c2bbb Initial commit: accidental classifier (sharp/flat/natural)
CNN trained on PrIMuS crops achieves 100% on held-out test set.
Includes training pipeline, evaluation script, extraction tools,
and saved model weights.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 08:03:56 -08:00
9 changed files with 1151 additions and 1 deletions

7
.gitignore vendored Normal file
View file

@ -0,0 +1,7 @@
__pycache__/
.claude/
crops/
dataset/
debug/
*.tgz
nul

View file

@ -1,2 +1,83 @@
# accidentals
# Accidental Classifier
A small CNN that classifies musical accidentals — **sharp**, **flat**, and **natural** — from cropped grayscale images of engraved scores.
Trained on crops extracted from the [PrIMuS dataset](https://grfia.dlsi.ua.es/primus/) and achieves **100% accuracy** on a held-out test set (750 samples).
```
Confusion matrix (rows=true, cols=predicted):
flat natural sharp
flat 250 0 0
natural 0 250 0
sharp 0 0 250
```
## Architecture
| | |
|---|---|
| **Input** | 1 &times; 40 &times; 40 grayscale |
| **Backbone** | 3 conv blocks (16 &rarr; 32 &rarr; 64 channels), each with BatchNorm + ReLU + MaxPool |
| **Head** | Dropout &rarr; FC(1600, 64) &rarr; ReLU &rarr; Dropout &rarr; FC(64, 3) |
| **Parameters** | ~115 k |
| **Output** | 3-class logits (flat / natural / sharp) |
A saved ONNX export is also produced by `train.py` for inference via OpenCV's `dnn` module.
## Repository layout
```
train.py Training pipeline (CNN + augmentation + weighted sampling)
eval.py Evaluate saved checkpoint on the held-out test set
extract_accidentals.py Extract accidental crops from PrIMuS MEI files via Verovio
extract_fast.py Faster single-font variant of the extractor
explore_dataset.py Explore PrIMuS agnostic encodings and image statistics
segment_test.py Connected-component symbol segmentation experiment
model/
accidental_classifier.pt Saved PyTorch weights (best validation accuracy)
```
## Pipeline
### 1. Extract crops
Render each PrIMuS MEI file with [Verovio](https://www.verovio.org/), locate accidental glyphs (SMuFL codepoints E260/E261/E262) in the SVG, rasterize with cairosvg, and crop each symbol into `crops/{flat,natural,sharp}/`.
```bash
python extract_fast.py
```
### 2. Train
Loads all crops into RAM, applies a stratified train/val/test split (seed 42), trains with data augmentation (random affine + erasing) and class-balanced sampling, and saves the best checkpoint.
```bash
python train.py
```
### 3. Evaluate
Reproduces the identical test split and reports accuracy, per-class metrics, the confusion matrix, and sample misclassifications.
```bash
python eval.py
```
## Dependencies
- Python 3.10+
- PyTorch
- torchvision
- NumPy
- Pillow
- OpenCV (`cv2`) — for ONNX verification and segmentation experiments
- Verovio (Python bindings) — for crop extraction only
- cairosvg — for crop extraction only
## Data
Training data is derived from [PrIMuS](https://grfia.dlsi.ua.es/primus/) (Calvo-Zaragoza & Rizo, 2018). The extracted crops and raw dataset are not included in this repository due to size.
## License
This project is released into the public domain.

148
eval.py Normal file
View file

@ -0,0 +1,148 @@
"""Evaluate the saved accidental classifier on the held-out test set.
Uses the same seed (42), split ratios, and data loading as train.py
so the test set is identical to what training used.
"""
import os
import random
from pathlib import Path
from collections import defaultdict
import numpy as np
import torch
from PIL import Image
from train import (
AccidentalCNN,
CROPS_ROOT,
OUTPUT_DIR,
LABELS,
IMG_SIZE,
SEED,
VAL_SPLIT,
TEST_SPLIT,
MAX_PER_CLASS,
InMemoryDataset,
load_all_into_ram,
split_data,
)
from torch.utils.data import DataLoader
# Best-validation-accuracy checkpoint written by train.py
MODEL_PATH = OUTPUT_DIR / "accidental_classifier.pt"
def collect_file_paths():
    """Collect crop file paths in the same order as load_all_into_ram.

    Must be called right after seeding: each random.shuffle here must
    consume the identical RNG state that train.py's loader consumed.
    """
    collected = []
    for class_name in LABELS:
        class_dir = CROPS_ROOT / class_name
        png_names = [name for name in os.listdir(class_dir) if name.endswith(".png")]
        random.shuffle(png_names)
        # Apply the same per-class cap as training.
        if MAX_PER_CLASS and len(png_names) > MAX_PER_CLASS:
            png_names = png_names[:MAX_PER_CLASS]
        collected.extend(str(class_dir / name) for name in png_names)
    return collected
def main():
    """Reload the exact test split used by train.py and report test metrics."""
    # Reproduce the exact same RNG state and splits as train.py
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    # Collect file paths (consumes same random state as load_all_into_ram)
    file_paths = collect_file_paths()
    # Re-seed and reload to get tensors (load_all_into_ram shuffles again internally)
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    images, labels = load_all_into_ram()
    _, _, test_idx = split_data(images, labels)
    print(f"\nTest set size: {len(test_idx)}")
    for i, name in enumerate(LABELS):
        n = (labels[test_idx] == i).sum().item()
        print(f" {name:8s}: {n}")
    # Load model (CPU is sufficient for this small CNN)
    device = torch.device("cpu")
    model = AccidentalCNN().to(device)
    model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))
    model.eval()
    print(f"\nLoaded weights from {MODEL_PATH}")
    # Run inference; shuffle=False keeps predictions aligned with test_idx order
    test_ds = InMemoryDataset(images[test_idx], labels[test_idx], augment=False)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=0)
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for imgs, lbls in test_loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            _, predicted = outputs.max(1)
            all_preds.append(predicted.cpu())
            all_labels.append(lbls)
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    # Overall accuracy
    correct = (all_preds == all_labels).sum().item()
    total = len(all_labels)
    print(f"\n=== Test Results ===")
    print(f"Overall accuracy: {correct}/{total} ({correct/total:.1%})")
    # Per-class accuracy
    print(f"\nPer-class accuracy:")
    for i, name in enumerate(LABELS):
        mask = all_labels == i
        cls_total = mask.sum().item()
        cls_correct = (all_preds[mask] == i).sum().item()
        acc = cls_correct / cls_total if cls_total > 0 else 0
        print(f" {name:8s}: {cls_correct}/{cls_total} ({acc:.1%})")
    # Confusion matrix
    n_classes = len(LABELS)
    cm = torch.zeros(n_classes, n_classes, dtype=torch.long)
    for t, p in zip(all_labels, all_preds):
        cm[t][p] += 1
    print(f"\nConfusion matrix (rows=true, cols=predicted):")
    header = " " + "".join(f"{name:>9s}" for name in LABELS)
    print(header)
    for i, name in enumerate(LABELS):
        row = f" {name:8s}" + "".join(f"{cm[i][j].item():9d}" for j in range(n_classes))
        print(row)
    # Sample misclassifications: map test-set positions back to original
    # file paths via collect_file_paths (same ordering as the loader).
    test_indices = test_idx.tolist()
    misclassified = []
    for i in range(len(all_preds)):
        if all_preds[i] != all_labels[i]:
            orig_idx = test_indices[i]
            misclassified.append((
                file_paths[orig_idx],
                LABELS[all_labels[i].item()],
                LABELS[all_preds[i].item()],
            ))
    print(f"\nMisclassifications: {len(misclassified)}/{total}")
    if misclassified:
        print(f"\nSample misclassifications (up to 10):")
        for path, true_lbl, pred_lbl in misclassified[:10]:
            print(f" {path}")
            print(f" true={true_lbl} predicted={pred_lbl}")

102
explore_dataset.py Normal file
View file

@ -0,0 +1,102 @@
"""Explore the PrIMuS dataset to understand accidental distribution and image structure."""
import glob
import os
from collections import Counter
from PIL import Image
import numpy as np
# Root of the extracted PrIMuS dataset (contains package_aa / package_ab)
DATASET_ROOT = r"C:\src\accidentals\dataset"
def parse_agnostic(path: str) -> list[str]:
    """Read an agnostic encoding file and return its tab-separated tokens."""
    with open(path, "r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents.strip().split("\t")
def main():
    """Report accidental-token statistics and basic image-size statistics."""
    # Find all agnostic files
    patterns = [
        os.path.join(DATASET_ROOT, "package_aa", "*", "*.agnostic"),
        os.path.join(DATASET_ROOT, "package_ab", "*", "*.agnostic"),
    ]
    agnostic_files = []
    for p in patterns:
        agnostic_files.extend(glob.glob(p))
    print(f"Total incipits: {len(agnostic_files)}")
    # Count accidental tokens
    accidental_type_counts = Counter()  # sharp/flat/natural
    accidental_full_counts = Counter()  # full token like accidental.sharp-L5
    incipits_with_accidentals = 0
    incipits_with_inline_accidentals = 0  # accidentals that aren't in key sig
    all_symbol_types = Counter()
    total_accidentals = 0
    for path in agnostic_files:
        tokens = parse_agnostic(path)
        has_any_accidental = False
        has_inline = False
        # Heuristic: a digit token marks the time signature; accidentals
        # after it are treated as inline rather than key-signature ones.
        past_time_sig = False
        for tok in tokens:
            # Track symbol types (just the prefix)
            base = tok.split("-")[0] if "-" in tok else tok
            all_symbol_types[base] += 1
            if tok.startswith("digit."):
                past_time_sig = True
            if tok.startswith("accidental."):
                has_any_accidental = True
                total_accidentals += 1
                # Extract type: accidental.sharp, accidental.flat, etc.
                acc_type = tok.split("-")[0]  # e.g. "accidental.sharp"
                accidental_type_counts[acc_type] += 1
                accidental_full_counts[tok] += 1
                if past_time_sig:
                    has_inline = True
        if has_any_accidental:
            incipits_with_accidentals += 1
        if has_inline:
            incipits_with_inline_accidentals += 1
    print(f"\n=== Accidental Statistics ===")
    print(f"Total accidental tokens: {total_accidentals}")
    print(f"Incipits with any accidentals: {incipits_with_accidentals} / {len(agnostic_files)} ({100*incipits_with_accidentals/len(agnostic_files):.1f}%)")
    print(f"Incipits with inline accidentals: {incipits_with_inline_accidentals} / {len(agnostic_files)} ({100*incipits_with_inline_accidentals/len(agnostic_files):.1f}%)")
    print(f"\n=== Accidental Type Counts ===")
    for acc_type, count in accidental_type_counts.most_common():
        print(f" {acc_type:25s} {count:7d}")
    print(f"\n=== Top 20 Accidental Positions ===")
    for tok, count in accidental_full_counts.most_common(20):
        print(f" {tok:30s} {count:7d}")
    print(f"\n=== Top 30 Symbol Types ===")
    for sym, count in all_symbol_types.most_common(30):
        print(f" {sym:30s} {count:7d}")
    # Image statistics from a sample (first 500 PNGs of package_aa)
    print(f"\n=== Image Statistics (sample of 500) ===")
    png_files = glob.glob(os.path.join(DATASET_ROOT, "package_aa", "*", "*.png"))[:500]
    widths, heights = [], []
    for f in png_files:
        im = Image.open(f)
        widths.append(im.size[0])
        heights.append(im.size[1])
    widths = np.array(widths)
    heights = np.array(heights)
    print(f" Width: min={widths.min()}, max={widths.max()}, mean={widths.mean():.0f}, std={widths.std():.0f}")
    print(f" Height: min={heights.min()}, max={heights.max()}, mean={heights.mean():.0f}, std={heights.std():.0f}")
    print(f" Modes: {Counter(Image.open(f).mode for f in png_files[:50])}")

214
extract_accidentals.py Normal file
View file

@ -0,0 +1,214 @@
"""Extract accidental crops from the PrIMuS dataset using Verovio re-rendering.
For each MEI file:
1. Render with Verovio to SVG (getting exact symbol positions)
2. Rasterize SVG to PNG with cairosvg
3. Parse SVG for accidental elements (sharp/flat/natural)
4. Crop each accidental from the rasterized image
5. Save organized by type: crops/{sharp,flat,natural}/
SMuFL glyph IDs:
E262 = sharp
E260 = flat
E261 = natural
"""
import glob
import os
import re
import sys
import traceback
from io import BytesIO
import cairosvg
import numpy as np
from PIL import Image
sys.stdout.reconfigure(encoding="utf-8")
DATASET_ROOT = r"C:\src\accidentals\dataset"
OUTPUT_ROOT = r"C:\src\accidentals\crops"
# SMuFL codepoint -> class label (E262 sharp, E260 flat, E261 natural)
GLYPH_MAP = {"E262": "sharp", "E260": "flat", "E261": "natural"}
# Regex to find accidental <use> elements in Verovio SVG.
# Groups: 1=class (keyAccid|accid), 2=glyph id, 3/4=translate x,y, 5/6=scale x,y
ACCID_RE = re.compile(
    r'class="(keyAccid|accid)"[^>]*>\s*'
    r'<use xlink:href="#(E\d+)[^"]*"\s+'
    r'transform="translate\((\d+),\s*(\d+)\)\s*scale\(([^,]+),\s*([^)]+)\)"'
)
# Pixel dimensions declared on the root <svg> element
SVG_DIM_RE = re.compile(r'width="(\d+)px" height="(\d+)px"')
# Crop parameters (in SVG pixel units, viewBox/10), relative to the glyph anchor
CROP_PAD_LEFT = 3
CROP_PAD_RIGHT = 22
CROP_PAD_TOP = 25
CROP_PAD_BOTTOM = 35
def setup_verovio():
    """Create a Verovio toolkit configured for tight single-system rendering.

    Margins and header/footer are disabled and the page is shrunk to fit
    so SVG coordinates map directly onto the rendered content.
    """
    import verovio

    render_options = {
        "adjustPageHeight": True,
        "adjustPageWidth": True,
        "header": "none",
        "footer": "none",
        "breaks": "none",
        "scale": 100,
        "pageMarginLeft": 0,
        "pageMarginRight": 0,
        "pageMarginTop": 0,
        "pageMarginBottom": 0,
    }
    toolkit = verovio.toolkit()
    toolkit.setOptions(render_options)
    return toolkit
def extract_from_mei(tk, mei_path: str) -> list[dict]:
    """Extract accidental crops from a single MEI file.

    Renders the MEI with Verovio, locates accidental glyphs in the SVG,
    rasterizes the SVG with cairosvg, and crops each glyph from the raster.

    Args:
        tk: Configured Verovio toolkit (see setup_verovio).
        mei_path: Path to the MEI file.

    Returns:
        List of dicts with 'type' (sharp/flat/natural), 'context'
        (keyAccid for key signatures, accid for inline accidentals), and
        'image' (grayscale PIL crop). Empty list on any rendering failure
        or when no accidentals are found.
    """
    with open(mei_path, "r", encoding="utf-8") as f:
        mei_data = f.read()
    tk.loadData(mei_data)
    svg = tk.renderToSVG(1)
    if not svg:
        return []
    # Get SVG pixel dimensions
    dim_match = SVG_DIM_RE.search(svg)
    if not dim_match:
        return []
    svg_w = int(dim_match.group(1))
    svg_h = int(dim_match.group(2))
    if svg_w == 0 or svg_h == 0:
        return []
    # Find accidentals in SVG
    accid_positions = []
    for match in ACCID_RE.finditer(svg):
        cls = match.group(1)
        glyph_id = match.group(2)
        tx, ty = int(match.group(3)), int(match.group(4))
        glyph_type = GLYPH_MAP.get(glyph_id)
        if glyph_type is None:
            continue
        # viewBox coords -> pixel coords (viewBox is 10x pixel dimensions)
        px = tx / 10.0
        py = ty / 10.0
        accid_positions.append(
            {"type": glyph_type, "context": cls, "px": px, "py": py}
        )
    if not accid_positions:
        return []
    # Rasterize SVG to PNG
    try:
        png_bytes = cairosvg.svg2png(
            bytestring=svg.encode("utf-8"),
            output_width=svg_w,
            output_height=svg_h,
        )
    except Exception:
        return []
    # FIX: cairosvg renders onto a transparent background; converting RGBA
    # straight to "L" drops the alpha channel and turns transparent pixels
    # black. Composite onto white first, matching extract_fast.py.
    img_rgba = Image.open(BytesIO(png_bytes)).convert("RGBA")
    white_bg = Image.new("RGBA", img_rgba.size, (255, 255, 255, 255))
    img = Image.alpha_composite(white_bg, img_rgba).convert("L")
    # Crop each accidental (pads are fixed offsets around the glyph anchor)
    results = []
    for acc in accid_positions:
        x1 = max(0, int(acc["px"] - CROP_PAD_LEFT))
        y1 = max(0, int(acc["py"] - CROP_PAD_TOP))
        x2 = min(svg_w, int(acc["px"] + CROP_PAD_RIGHT))
        y2 = min(svg_h, int(acc["py"] + CROP_PAD_BOTTOM))
        if x2 - x1 < 5 or y2 - y1 < 5:
            continue  # degenerate crop at the image edge
        crop = img.crop((x1, y1, x2, y2))
        results.append(
            {
                "type": acc["type"],
                "context": acc["context"],
                "image": crop,
            }
        )
    return results
def main():
    """Extract accidental crops from every MEI file across three music fonts."""
    # Find all MEI files
    mei_files = []
    for pkg in ["package_aa", "package_ab"]:
        mei_files.extend(
            glob.glob(os.path.join(DATASET_ROOT, pkg, "*", "*.mei"))
        )
    mei_files.sort()
    print(f"Found {len(mei_files)} MEI files")
    # Create output directories, one per class label
    for label in GLYPH_MAP.values():
        os.makedirs(os.path.join(OUTPUT_ROOT, label), exist_ok=True)
    tk = setup_verovio()
    # Process with multiple Verovio fonts for visual variety in the crops
    fonts = ["Leipzig", "Bravura", "Gootville"]
    counts = {"sharp": 0, "flat": 0, "natural": 0}
    errors = 0
    total_processed = 0
    for font_idx, font in enumerate(fonts):
        tk.setOptions({"font": font})
        print(f"\n=== Font: {font} (pass {font_idx+1}/{len(fonts)}) ===")
        for i, mei_path in enumerate(mei_files):
            # Periodic progress report
            if i % 5000 == 0:
                print(
                    f" [{font}] {i}/{len(mei_files)} "
                    f"sharp={counts['sharp']} flat={counts['flat']} "
                    f"natural={counts['natural']} errors={errors}"
                )
            try:
                results = extract_from_mei(tk, mei_path)
            except Exception:
                # Best-effort batch job: count failures, show only first few.
                errors += 1
                if errors <= 5:
                    traceback.print_exc()
                continue
            # Sample id is the incipit directory name; the font index keeps
            # filenames unique across passes.
            sample_id = os.path.basename(os.path.dirname(mei_path))
            for j, res in enumerate(results):
                label = res["type"]
                fname = f"{sample_id}_f{font_idx}_{j}.png"
                out_path = os.path.join(OUTPUT_ROOT, label, fname)
                res["image"].save(out_path)
                counts[label] += 1
            total_processed += 1
    print(f"\n=== Done ===")
    print(f"Processed: {total_processed} MEI files across {len(fonts)} fonts")
    print(f"Errors: {errors}")
    print(f"Extracted crops:")
    for label, count in sorted(counts.items()):
        print(f" {label:10s}: {count:7d}")
    print(f"Output: {OUTPUT_ROOT}")

161
extract_fast.py Normal file
View file

@ -0,0 +1,161 @@
"""Fast accidental extraction from PrIMuS using Verovio + cairosvg.
Optimizations:
- Pre-filter MEI files using agnostic encoding (skip files with no accidentals)
- Single font pass (Leipzig) - augmentation can be done at training time
- Flush output for progress monitoring
"""
import glob
import os
import re
import sys
import time
import traceback
from io import BytesIO
import cairosvg
import numpy as np
from PIL import Image
sys.stdout.reconfigure(encoding="utf-8")
DATASET_ROOT = r"C:\src\accidentals\dataset"
OUTPUT_ROOT = r"C:\src\accidentals\crops"
# SMuFL codepoint -> class label (E262 sharp, E260 flat, E261 natural)
GLYPH_MAP = {"E262": "sharp", "E260": "flat", "E261": "natural"}
# Accidental <use> elements in Verovio SVG output.
# Groups: 1=class (keyAccid|accid), 2=glyph id, 3/4=translate x,y, 5/6=scale x,y
ACCID_RE = re.compile(
    r'class="(keyAccid|accid)"[^>]*>\s*'
    r'<use xlink:href="#(E\d+)[^"]*"\s+'
    r'transform="translate\((\d+),\s*(\d+)\)\s*scale\(([^,]+),\s*([^)]+)\)"'
)
# Pixel dimensions declared on the root <svg> element
SVG_DIM_RE = re.compile(r'width="(\d+)px" height="(\d+)px"')
def has_accidentals(agnostic_path: str) -> bool:
    """Return True if the agnostic encoding mentions any accidental token."""
    with open(agnostic_path, "r", encoding="utf-8") as handle:
        text = handle.read()
    return "accidental." in text
def main():
    """Single-pass (Leipzig font) accidental crop extraction over all MEI files."""
    import verovio
    # Glob all MEI files at once (much faster than iterating dirs on Windows)
    print("Finding MEI files...", flush=True)
    mei_files = sorted(
        glob.glob(os.path.join(DATASET_ROOT, "package_*", "*", "*.mei"))
    )
    # NOTE(review): has_accidentals is defined above but not applied here;
    # per the message below the agnostic pre-filter is deliberately skipped.
    print(f"Found {len(mei_files)} MEI files (skipping pre-filter)", flush=True)
    # Create output directories, one per class label
    for label in GLYPH_MAP.values():
        os.makedirs(os.path.join(OUTPUT_ROOT, label), exist_ok=True)
    # Setup Verovio: tight single-system rendering, no margins/header/footer
    tk = verovio.toolkit()
    tk.setOptions({
        "adjustPageHeight": True,
        "adjustPageWidth": True,
        "header": "none",
        "footer": "none",
        "breaks": "none",
        "scale": 100,
        "pageMarginLeft": 0,
        "pageMarginRight": 0,
        "pageMarginTop": 0,
        "pageMarginBottom": 0,
        "font": "Leipzig",
    })
    counts = {"sharp": 0, "flat": 0, "natural": 0}
    errors = 0
    t0 = time.time()
    for i, mei_path in enumerate(mei_files):
        # Throughput/ETA progress line every 1000 files
        if i % 1000 == 0:
            elapsed = time.time() - t0
            rate = i / elapsed if elapsed > 0 else 0
            eta = (len(mei_files) - i) / rate if rate > 0 else 0
            print(
                f"[{i:6d}/{len(mei_files)}] "
                f"sharp={counts['sharp']} flat={counts['flat']} "
                f"natural={counts['natural']} "
                f"err={errors} "
                f"rate={rate:.0f}/s eta={eta/60:.0f}m",
                flush=True,
            )
        try:
            with open(mei_path, "r", encoding="utf-8") as f:
                mei_data = f.read()
            tk.loadData(mei_data)
            svg = tk.renderToSVG(1)
            if not svg:
                continue
            dim_match = SVG_DIM_RE.search(svg)
            if not dim_match:
                continue
            svg_w = int(dim_match.group(1))
            svg_h = int(dim_match.group(2))
            if svg_w == 0 or svg_h == 0:
                continue
            # Find accidentals (translate coords are viewBox units = 10x pixels)
            accids = []
            for match in ACCID_RE.finditer(svg):
                glyph_type = GLYPH_MAP.get(match.group(2))
                if glyph_type:
                    accids.append({
                        "type": glyph_type,
                        "px": int(match.group(3)) / 10.0,
                        "py": int(match.group(4)) / 10.0,
                    })
            if not accids:
                continue
            # Rasterize
            png_bytes = cairosvg.svg2png(
                bytestring=svg.encode("utf-8"),
                output_width=svg_w,
                output_height=svg_h,
            )
            # Composite onto white: transparent pixels would otherwise go
            # black when the RGBA raster is reduced to grayscale.
            img_rgba = Image.open(BytesIO(png_bytes))
            white_bg = Image.new("RGBA", img_rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(white_bg, img_rgba).convert("L")
            sample_id = os.path.basename(os.path.dirname(mei_path))
            for j, acc in enumerate(accids):
                # Fixed pads around the glyph anchor (same values as
                # CROP_PAD_* in extract_accidentals.py)
                x1 = max(0, int(acc["px"] - 3))
                y1 = max(0, int(acc["py"] - 25))
                x2 = min(svg_w, int(acc["px"] + 22))
                y2 = min(svg_h, int(acc["py"] + 35))
                if x2 - x1 < 5 or y2 - y1 < 5:
                    continue  # degenerate crop at the image edge
                crop = img.crop((x1, y1, x2, y2))
                fname = f"{sample_id}_{j}.png"
                crop.save(os.path.join(OUTPUT_ROOT, acc["type"], fname))
                counts[acc["type"]] += 1
        except Exception:
            # Best-effort: count failures, print only the first few tracebacks.
            errors += 1
            if errors <= 5:
                traceback.print_exc()
    elapsed = time.time() - t0
    print(f"\n=== Done in {elapsed/60:.1f} minutes ===", flush=True)
    print(f"Errors: {errors}")
    print(f"Crops extracted:")
    for label, count in sorted(counts.items()):
        print(f" {label:10s}: {count:7d}")
    print(f"Total: {sum(counts.values())}")
    print(f"Output: {OUTPUT_ROOT}")

Binary file not shown.

122
segment_test.py Normal file
View file

@ -0,0 +1,122 @@
"""Test segmenting symbols from a PrIMuS image using connected components."""
import cv2
import numpy as np
from PIL import Image
def segment_symbols(png_path: str):
    """Segment a PrIMuS image into individual symbols.

    Returns (clusters, grayscale image, binary image with staff lines
    removed), where clusters are left-to-right merged bounding boxes.
    """
    # Load as grayscale
    img = cv2.imread(png_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # PrIMuS uses palette mode PNGs, load via PIL first
        pil_img = Image.open(png_path).convert("L")
        img = np.array(pil_img)
    h, w = img.shape
    print(f"Image size: {w} x {h}")
    # Binarize (black ink on white background -> white-on-black after INV)
    _, binary = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY_INV)
    # Remove staff lines using horizontal morphology:
    # a very wide, 1-px-tall opening keeps only long horizontal runs.
    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (w // 4, 1))
    staff_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horizontal_kernel)
    # Remove staff lines from binary image
    # (opening output is a subset of binary, so subtraction cannot underflow)
    no_lines = binary - staff_lines
    # Clean up with small morphological close to reconnect broken symbols
    close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    no_lines = cv2.morphologyEx(no_lines, cv2.MORPH_CLOSE, close_kernel)
    # Find connected components
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        no_lines, connectivity=8
    )
    # Filter out tiny noise (area < 10 pixels)
    symbols = []
    for i in range(1, num_labels):  # skip background
        x, y, sw, sh, area = stats[i]
        if area < 10:
            continue
        symbols.append({
            "x": x, "y": y, "w": sw, "h": sh,
            "area": area, "cx": centroids[i][0],
        })
    # Sort by x position (left to right)
    symbols.sort(key=lambda s: s["x"])
    # Group nearby components into symbol clusters
    # (a single symbol like a sharp may have multiple disconnected parts)
    clusters = []
    for s in symbols:
        # Merge when this component starts within 5px of the previous
        # cluster's right edge.
        if clusters and s["x"] < clusters[-1]["x2"] + 5:
            # Merge into previous cluster
            c = clusters[-1]
            c["x1"] = min(c["x1"], s["x"])
            c["y1"] = min(c["y1"], s["y"])
            c["x2"] = max(c["x2"], s["x"] + s["w"])
            c["y2"] = max(c["y2"], s["y"] + s["h"])
            c["components"].append(s)
        else:
            clusters.append({
                "x1": s["x"], "y1": s["y"],
                "x2": s["x"] + s["w"], "y2": s["y"] + s["h"],
                "components": [s],
            })
    print(f"Found {len(symbols)} components -> {len(clusters)} symbol clusters")
    for i, c in enumerate(clusters):
        w = c["x2"] - c["x1"]
        h = c["y2"] - c["y1"]
        print(f" Cluster {i:3d}: x={c['x1']:5d}-{c['x2']:5d} y={c['y1']:3d}-{c['y2']:3d} size={w:3d}x{h:3d}")
    return clusters, img, no_lines
def match_tokens_to_clusters(agnostic_path: str, clusters: list):
    """Print agnostic tokens alongside the cluster count for manual comparison.

    Args:
        agnostic_path: Path to the tab-separated agnostic encoding file.
        clusters: Symbol clusters from segment_symbols (only the count is used).

    Returns:
        The list of agnostic tokens.
    """
    # FIX: pass encoding explicitly, consistent with every other script in
    # the repo (the platform default encoding differs on Windows).
    with open(agnostic_path, "r", encoding="utf-8") as f:
        tokens = f.read().strip().split("\t")
    print(f"\nAgnostic tokens ({len(tokens)}):")
    for i, tok in enumerate(tokens):
        print(f" {i:3d}: {tok}")
    print(f"\nClusters: {len(clusters)}")
    print(f"Tokens: {len(tokens)}")
    return tokens
if __name__ == "__main__":
    # Single hard-coded sample used for the segmentation experiment.
    sample_dir = r"C:\src\accidentals\dataset\package_aa\000141058-1_1_1"
    sample_id = "000141058-1_1_1"
    png_path = f"{sample_dir}/{sample_id}.png"
    agnostic_path = f"{sample_dir}/{sample_id}.agnostic"
    clusters, img, no_lines = segment_symbols(png_path)
    tokens = match_tokens_to_clusters(agnostic_path, clusters)
    # Save debug images
    debug_dir = r"C:\src\accidentals\debug"
    import os
    os.makedirs(debug_dir, exist_ok=True)
    # Invert back to black-on-white for easier visual inspection.
    cv2.imwrite(f"{debug_dir}/no_lines.png", 255 - no_lines)
    # Draw bounding boxes on original
    vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    for i, c in enumerate(clusters):
        color = (0, 0, 255)  # red in BGR
        cv2.rectangle(vis, (c["x1"], c["y1"]), (c["x2"], c["y2"]), color, 1)
        cv2.putText(vis, str(i), (c["x1"], c["y1"] - 2),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)
    cv2.imwrite(f"{debug_dir}/segmented.png", vis)
    print(f"\nDebug images saved to {debug_dir}/")

315
train.py Normal file
View file

@ -0,0 +1,315 @@
"""Train a small CNN to classify accidentals (sharp/flat/natural).
Outputs an ONNX model for inference via OpenCV's dnn module.
"""
import os
import sys
import time
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from PIL import Image
from torchvision import transforms
sys.stdout.reconfigure(encoding="utf-8")
CROPS_ROOT = Path(r"C:\src\accidentals\crops")  # extracted crops, one subfolder per class
OUTPUT_DIR = Path(r"C:\src\accidentals\model")  # checkpoint + ONNX destination
LABELS = ["flat", "natural", "sharp"]  # alphabetical → class indices 0,1,2
IMG_SIZE = (40, 40)  # square, small enough for fast inference
BATCH_SIZE = 256
NUM_EPOCHS = 30
LR = 1e-3  # initial Adam learning rate (halved on plateau by the scheduler)
SEED = 42  # shared with eval.py so the data split is reproducible
VAL_SPLIT = 0.1
TEST_SPLIT = 0.05
MAX_PER_CLASS = 5000  # subsample: more than enough for this easy problem
class InMemoryDataset(Dataset):
    """Dataset over pre-loaded image/label tensors held entirely in RAM.

    Args:
        images: (N, 1, H, W) float32 tensor of grayscale crops.
        labels: (N,) long tensor of class indices.
        augment: When True, apply random affine jitter and random erasing
            to each sample on access.
    """

    def __init__(self, images: torch.Tensor, labels: torch.Tensor, augment=False):
        self.images = images
        self.labels = labels
        self.augment = augment
        # Augmentation transforms are created unconditionally but only
        # applied when self.augment is set.
        self.affine = transforms.RandomAffine(
            degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)
        )
        self.erasing = transforms.RandomErasing(p=0.2, scale=(0.02, 0.08))

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        sample = self.images[idx]
        if not self.augment:
            return sample, self.labels[idx]
        return self.erasing(self.affine(sample)), self.labels[idx]
class AccidentalCNN(nn.Module):
    """Small CNN for 3-class accidental classification.

    Input: 1x40x40 grayscale image
    Architecture: 3 conv blocks + 1 FC layer

    NOTE: the Sequential layer ordering determines the state_dict keys of
    the saved checkpoint — do not reorder/renumber layers without retraining.
    """
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1: 1 -> 16 channels, 40x40 -> 20x20
            nn.Conv2d(1, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Block 2: 16 -> 32 channels, 20x20 -> 10x10
            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Block 3: 32 -> 64 channels, 10x10 -> 5x5
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(64 * 5 * 5, 64),  # 64 channels * 5x5 spatial = 1600 features
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(64, len(LABELS)),  # 3-class logits
        )

    def forward(self, x):
        """Return (B, 3) class logits for a (B, 1, 40, 40) input batch."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (B, 1600)
        x = self.classifier(x)
        return x
def load_all_into_ram():
    """Load all crop images into RAM as tensors. Returns (images, labels).

    images: (N, 1, H, W) float32 in [0, 1]; labels: (N,) long.

    NOTE: eval.py reproduces the file ordering by replaying the exact same
    random.shuffle calls — keep RNG consumption here in sync with
    eval.collect_file_paths.
    """
    print("Loading all images into RAM...", flush=True)
    all_images = []
    all_labels = []
    for label_idx, label_name in enumerate(LABELS):
        label_dir = CROPS_ROOT / label_name
        files = [f for f in os.listdir(label_dir) if f.endswith(".png")]
        # Shuffle before capping so MAX_PER_CLASS takes a random subsample.
        random.shuffle(files)
        if MAX_PER_CLASS and len(files) > MAX_PER_CLASS:
            files = files[:MAX_PER_CLASS]
        print(f" {label_name}: {len(files)} files", flush=True)
        for fname in files:
            img = Image.open(label_dir / fname).convert("L")
            img = img.resize(IMG_SIZE, Image.BILINEAR)
            arr = np.array(img, dtype=np.float32) / 255.0
            all_images.append(arr)
            all_labels.append(label_idx)
    # Stack into tensors: (N, 1, H, W)
    images = torch.tensor(np.array(all_images)).unsqueeze(1)
    labels = torch.tensor(all_labels, dtype=torch.long)
    print(f" Total: {len(labels)} images, tensor shape: {images.shape}", flush=True)
    print(f" Memory: {images.nbytes / 1024**2:.0f} MB", flush=True)
    return images, labels
def split_data(images, labels):
    """Stratified train/val/test split.

    Shuffles all indices with torch's RNG, buckets them per class, then
    carves off TEST_SPLIT and VAL_SPLIT fractions (at least one sample
    each) from every class. The images argument is accepted for symmetry
    but only labels are consulted.

    Returns (train, val, test) index tensors.
    """
    shuffled = torch.randperm(len(labels))
    per_class = {c: [] for c in range(len(LABELS))}
    for pos in shuffled:
        per_class[labels[pos].item()].append(pos.item())
    train_idx = []
    val_idx = []
    test_idx = []
    for idxs in per_class.values():
        count = len(idxs)
        n_test = max(1, int(count * TEST_SPLIT))
        n_val = max(1, int(count * VAL_SPLIT))
        test_idx += idxs[:n_test]
        val_idx += idxs[n_test:n_test + n_val]
        train_idx += idxs[n_test + n_val:]
    return torch.tensor(train_idx), torch.tensor(val_idx), torch.tensor(test_idx)
def make_weighted_sampler(labels):
    """Build a sampler that draws every class with equal expected frequency.

    Each sample's draw weight is the inverse of its class count, so
    minority classes are oversampled (with replacement).
    """
    counts = torch.bincount(labels, minlength=len(LABELS)).float()
    per_sample_weight = 1.0 / counts[labels]
    return WeightedRandomSampler(per_sample_weight, len(labels), replacement=True)
def evaluate(model, loader, device):
    """Compute (avg loss, overall accuracy, per-class accuracy) on a loader."""
    model.eval()
    n_correct = 0
    n_seen = 0
    per_class_hits = np.zeros(len(LABELS))
    per_class_seen = np.zeros(len(LABELS))
    loss_sum = 0.0
    batch_count = 0
    with torch.no_grad():
        for batch_imgs, batch_lbls in loader:
            batch_imgs, batch_lbls = batch_imgs.to(device), batch_lbls.to(device)
            logits = model(batch_imgs)
            loss_sum += F.cross_entropy(logits, batch_lbls).item()
            batch_count += 1
            _, preds = logits.max(1)
            n_seen += batch_lbls.size(0)
            n_correct += preds.eq(batch_lbls).sum().item()
            # Accumulate per-class hit/seen counts across batches.
            for cls in range(len(LABELS)):
                is_cls = batch_lbls == cls
                per_class_seen[cls] += is_cls.sum().item()
                per_class_hits[cls] += (preds[is_cls] == cls).sum().item()
    overall = n_correct / n_seen if n_seen > 0 else 0
    per_class = np.where(per_class_seen > 0, per_class_hits / per_class_seen, 0)
    mean_loss = loss_sum / batch_count if batch_count > 0 else 0
    return mean_loss, overall, per_class
def main():
    """Train the accidental CNN end-to-end and export PyTorch + ONNX models."""
    # Fixed seeds so eval.py can reproduce the identical data split later.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    device = torch.device("cpu")
    print(f"Device: {device}", flush=True)
    # Load everything into RAM
    images, labels = load_all_into_ram()
    # Split
    train_idx, val_idx, test_idx = split_data(images, labels)
    print(f"\nTrain: {len(train_idx)} Val: {len(val_idx)} Test: {len(test_idx)}", flush=True)
    for i, name in enumerate(LABELS):
        nt = (labels[train_idx] == i).sum().item()
        nv = (labels[val_idx] == i).sum().item()
        ne = (labels[test_idx] == i).sum().item()
        print(f" {name:8s}: train={nt:6d} val={nv:5d} test={ne:5d}", flush=True)
    # Create datasets (in-memory, fast); only training data is augmented
    train_ds = InMemoryDataset(images[train_idx], labels[train_idx], augment=True)
    val_ds = InMemoryDataset(images[val_idx], labels[val_idx], augment=False)
    test_ds = InMemoryDataset(images[test_idx], labels[test_idx], augment=False)
    # Weighted sampler for class imbalance
    sampler = make_weighted_sampler(labels[train_idx])
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    # Model
    model = AccidentalCNN().to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {param_count:,}", flush=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # Halve the LR when validation accuracy plateaus for 3 epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=3
    )
    # Training loop
    best_val_acc = 0.0
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    best_model_path = OUTPUT_DIR / "accidental_classifier.pt"
    print(f"\n{'Epoch':>5s} {'TrLoss':>7s} {'VaLoss':>7s} {'VaAcc':>6s} "
          f"{'Flat':>5s} {'Nat':>5s} {'Sharp':>5s} {'LR':>8s} {'Time':>5s}", flush=True)
    print("-" * 70, flush=True)
    for epoch in range(1, NUM_EPOCHS + 1):
        t0 = time.time()
        model.train()
        train_loss = 0.0
        n_batches = 0
        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = F.cross_entropy(outputs, lbls)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_batches += 1
        avg_train_loss = train_loss / n_batches
        val_loss, val_acc, class_acc = evaluate(model, val_loader, device)
        scheduler.step(val_acc)
        elapsed = time.time() - t0
        lr = optimizer.param_groups[0]["lr"]
        print(
            f"{epoch:5d} {avg_train_loss:7.4f} {val_loss:7.4f} {val_acc:6.1%} "
            f"{class_acc[0]:5.1%} {class_acc[1]:5.1%} {class_acc[2]:5.1%} "
            f"{lr:.6f} {elapsed:5.1f}s",
            flush=True,
        )
        # Checkpoint only on a new best validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
    # Final evaluation on test set, using the best checkpoint
    print(f"\n=== Test Set Evaluation ===", flush=True)
    model.load_state_dict(torch.load(best_model_path, weights_only=True))
    test_loss, test_acc, test_class_acc = evaluate(model, test_loader, device)
    print(f"Test accuracy: {test_acc:.1%}", flush=True)
    for i, name in enumerate(LABELS):
        print(f" {name:8s}: {test_class_acc[i]:.1%}", flush=True)
    # Export to ONNX for OpenCV dnn inference (batch dim left dynamic)
    print(f"\nExporting to ONNX...", flush=True)
    model.eval()
    dummy = torch.randn(1, 1, *IMG_SIZE)
    onnx_path = OUTPUT_DIR / "accidental_classifier.onnx"
    torch.onnx.export(
        model,
        dummy,
        str(onnx_path),
        input_names=["image"],
        output_names=["logits"],
        dynamic_axes={"image": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=13,
    )
    print(f"Saved ONNX model: {onnx_path}", flush=True)
    print(f"Saved PyTorch model: {best_model_path}", flush=True)
    # Quick ONNX verification via OpenCV: just checks the graph loads and
    # runs, not prediction quality (the input is random noise).
    import cv2
    net = cv2.dnn.readNetFromONNX(str(onnx_path))
    test_img = np.random.randint(0, 255, (IMG_SIZE[1], IMG_SIZE[0]), dtype=np.uint8)
    blob = cv2.dnn.blobFromImage(test_img, 1.0 / 255.0, IMG_SIZE, 0, swapRB=False)
    net.setInput(blob)
    out = net.forward()
    pred = LABELS[np.argmax(out[0])]
    print(f"ONNX+OpenCV smoke test: input=random noise -> pred={pred} (expected: arbitrary)", flush=True)
    print(f"\nDone. Labels: {LABELS}", flush=True)