accidentals/extract_fast.py
dullfig ebc925482e Initial commit: accidental classifier (sharp/flat/natural)
CNN trained on PrIMuS crops achieves 100% on held-out test set.
Includes training pipeline, evaluation script, extraction tools,
and saved model weights.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 08:01:37 -08:00

161 lines
4.9 KiB
Python

"""Fast accidental extraction from PrIMuS using Verovio + cairosvg.
Optimizations:
- Single font pass (Leipzig) - augmentation can be done at training time
- Flush output for progress monitoring
Note: has_accidentals() can pre-filter MEI files using their agnostic encoding
(skipping files with no accidentals), but the main run does not apply it.
"""
import glob
import os
import re
import sys
import time
import traceback
from io import BytesIO
import cairosvg
import numpy as np
from PIL import Image
# Force UTF-8 on stdout (presumably so progress output never hits a
# UnicodeEncodeError on a non-UTF-8 Windows console). Guarded because
# sys.stdout can be None (e.g. under pythonw) or a stream without reconfigure().
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")

# Input/output roots. The environment variables override the defaults so the
# script can run on other machines without editing the source.
DATASET_ROOT = os.environ.get(
    "ACCIDENTALS_DATASET_ROOT", r"C:\src\accidentals\dataset"
)
OUTPUT_ROOT = os.environ.get(
    "ACCIDENTALS_OUTPUT_ROOT", r"C:\src\accidentals\crops"
)

# SMuFL glyph code points (as referenced by Verovio's <use xlink:href="#E26x">)
# mapped to the class label / output sub-directory name.
GLYPH_MAP = {"E262": "sharp", "E260": "flat", "E261": "natural"}

# Matches an accidental group (class "accid" or "keyAccid") followed by its
# <use> glyph reference. Groups: 1=class, 2=glyph code (e.g. E262),
# 3/4=integer translate x/y, 5/6=scale x/y.
ACCID_RE = re.compile(
    r'class="(keyAccid|accid)"[^>]*>\s*'
    r'<use xlink:href="#(E\d+)[^"]*"\s+'
    r'transform="translate\((\d+),\s*(\d+)\)\s*scale\(([^,]+),\s*([^)]+)\)"'
)

# Pixel width/height declared on the outer <svg> element.
SVG_DIM_RE = re.compile(r'width="(\d+)px" height="(\d+)px"')
def has_accidentals(agnostic_path: str) -> bool:
    """Report whether an agnostic-encoding file mentions any accidental token.

    The whole file is read as UTF-8 text and searched for the literal
    substring "accidental."; no tokenization is performed.

    Args:
        agnostic_path: Path to the agnostic-encoding text file.

    Returns:
        True if the substring occurs anywhere in the file, else False.
    """
    needle = "accidental."
    with open(agnostic_path, mode="r", encoding="utf-8") as agnostic_file:
        text = agnostic_file.read()
    return needle in text
def _make_toolkit():
    """Create a Verovio toolkit with the rendering options used for extraction.

    Page margins are zeroed, the page is auto-sized to its content,
    header/footer are disabled, automatic breaks are turned off, and the
    Leipzig font is selected. verovio is imported here (locally, as in the
    original script), so this module can be imported without verovio installed.
    """
    import verovio

    tk = verovio.toolkit()
    tk.setOptions({
        "adjustPageHeight": True,
        "adjustPageWidth": True,
        "header": "none",
        "footer": "none",
        "breaks": "none",
        "scale": 100,
        "pageMarginLeft": 0,
        "pageMarginRight": 0,
        "pageMarginTop": 0,
        "pageMarginBottom": 0,
        "font": "Leipzig",
    })
    return tk


def _find_accidentals(svg):
    """Return the sharp/flat/natural glyphs found in a Verovio SVG string.

    Each entry is a dict with "type" (a GLYPH_MAP label) and "px"/"py", which
    are the glyph's translate() offsets divided by 10. Glyph codes not in
    GLYPH_MAP are ignored.
    """
    accids = []
    for match in ACCID_RE.finditer(svg):
        glyph_type = GLYPH_MAP.get(match.group(2))
        if glyph_type:
            accids.append({
                "type": glyph_type,
                # NOTE(review): /10 is assumed to map Verovio's internal SVG
                # units to output pixels at "scale": 100 - confirm vs viewBox.
                "px": int(match.group(3)) / 10.0,
                "py": int(match.group(4)) / 10.0,
            })
    return accids


def _rasterize_gray(svg, width, height):
    """Rasterize an SVG string to a grayscale PIL image of width x height.

    The PNG is composited onto opaque white before the "L" conversion, so
    transparent areas come out white instead of keeping whatever RGB value
    they carry.
    """
    png_bytes = cairosvg.svg2png(
        bytestring=svg.encode("utf-8"),
        output_width=width,
        output_height=height,
    )
    # Image.alpha_composite requires both inputs in RGBA mode; convert
    # explicitly rather than relying on the PNG's stored mode.
    img_rgba = Image.open(BytesIO(png_bytes)).convert("RGBA")
    white_bg = Image.new("RGBA", img_rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(white_bg, img_rgba).convert("L")


def _crop_box(px, py, width, height):
    """Return the (x1, y1, x2, y2) crop window around a glyph position.

    Fixed offsets around (px, py) are clamped to the image bounds. Returns
    None when the clamped window is under 5 px wide or tall.
    """
    x1 = max(0, int(px - 3))
    y1 = max(0, int(py - 25))
    x2 = min(width, int(px + 22))
    y2 = min(height, int(py + 35))
    if x2 - x1 < 5 or y2 - y1 < 5:
        return None
    return x1, y1, x2, y2


def _print_progress(done, total, counts, errors, t0):
    """Print one progress line: files done, crop counts, errors, rate and ETA."""
    elapsed = time.time() - t0
    rate = done / elapsed if elapsed > 0 else 0
    eta = (total - done) / rate if rate > 0 else 0
    print(
        f"[{done:6d}/{total}] "
        f"sharp={counts['sharp']} flat={counts['flat']} "
        f"natural={counts['natural']} "
        f"err={errors} "
        f"rate={rate:.0f}/s eta={eta/60:.0f}m",
        flush=True,
    )


def _process_file(tk, mei_path, counts):
    """Render one MEI file and save a crop for each accidental it contains.

    Crops are written to OUTPUT_ROOT/<label>/<sample_id>_<j>.png, where
    sample_id is the name of the MEI file's parent directory and j is the
    accidental's index in the SVG. counts[label] is incremented per crop saved.
    Files that render to nothing, report a zero size, or contain no
    sharp/flat/natural glyphs are skipped silently.

    Raises:
        RuntimeError: if Verovio reports that loading the MEI data failed.
        Any I/O, cairosvg or PIL error propagates to the caller.
    """
    with open(mei_path, "r", encoding="utf-8") as f:
        mei_data = f.read()
    # loadData returns False on failure. Check it so the failure is counted as
    # an error instead of rendering on regardless (possibly a previously loaded
    # document, depending on Verovio's reset behaviour - unverified).
    if not tk.loadData(mei_data):
        raise RuntimeError(f"Verovio failed to load {mei_path}")
    svg = tk.renderToSVG(1)
    if not svg:
        return
    dim_match = SVG_DIM_RE.search(svg)
    if not dim_match:
        return
    svg_w = int(dim_match.group(1))
    svg_h = int(dim_match.group(2))
    if svg_w == 0 or svg_h == 0:
        return
    accids = _find_accidentals(svg)
    if not accids:
        return
    img = _rasterize_gray(svg, svg_w, svg_h)
    sample_id = os.path.basename(os.path.dirname(mei_path))
    for j, acc in enumerate(accids):
        box = _crop_box(acc["px"], acc["py"], svg_w, svg_h)
        if box is None:
            continue
        fname = f"{sample_id}_{j}.png"
        img.crop(box).save(os.path.join(OUTPUT_ROOT, acc["type"], fname))
        counts[acc["type"]] += 1


def main():
    """Extract accidental crops from every PrIMuS MEI file under DATASET_ROOT.

    Globs DATASET_ROOT/package_*/*/*.mei, renders each file with Verovio, and
    saves one grayscale crop per sharp/flat/natural glyph under
    OUTPUT_ROOT/<label>/. Prints a progress line every 1000 files and a summary
    at the end. A failure in a single file is counted as an error (the first
    five tracebacks are printed) and does not stop the run.
    """
    # Build the toolkit first so a missing verovio install fails before the glob.
    tk = _make_toolkit()

    # Glob all MEI files at once (much faster than iterating dirs on Windows)
    print("Finding MEI files...", flush=True)
    mei_files = sorted(
        glob.glob(os.path.join(DATASET_ROOT, "package_*", "*", "*.mei"))
    )
    total = len(mei_files)
    print(f"Found {total} MEI files (skipping pre-filter)", flush=True)

    for label in GLYPH_MAP.values():
        os.makedirs(os.path.join(OUTPUT_ROOT, label), exist_ok=True)

    # One counter per GLYPH_MAP label, so every label _process_file can emit
    # has a slot.
    counts = {label: 0 for label in GLYPH_MAP.values()}
    errors = 0
    t0 = time.time()
    for i, mei_path in enumerate(mei_files):
        if i % 1000 == 0:
            _print_progress(i, total, counts, errors, t0)
        try:
            _process_file(tk, mei_path, counts)
        except Exception:
            # Deliberate best-effort per file: count it and show a few tracebacks.
            errors += 1
            if errors <= 5:
                traceback.print_exc()

    elapsed = time.time() - t0
    print(f"\n=== Done in {elapsed/60:.1f} minutes ===", flush=True)
    print(f"Errors: {errors}")
    print("Crops extracted:")
    for label, count in sorted(counts.items()):
        print(f" {label:10s}: {count:7d}")
    print(f"Total: {sum(counts.values())}")
    print(f"Output: {OUTPUT_ROOT}")


if __name__ == "__main__":
    main()