accidentals/extract_accidentals.py
dullfig afc16c2bbb Initial commit: accidental classifier (sharp/flat/natural)
CNN trained on PrIMuS crops achieves 100% on held-out test set.
Includes training pipeline, evaluation script, extraction tools,
and saved model weights.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 08:03:56 -08:00

214 lines
5.7 KiB
Python

"""Extract accidental crops from the PrIMuS dataset using Verovio re-rendering.
For each MEI file:
1. Render with Verovio to SVG (getting exact symbol positions)
2. Rasterize SVG to PNG with cairosvg
3. Parse SVG for accidental elements (sharp/flat/natural)
4. Crop each accidental from the rasterized image
5. Save organized by type: crops/{sharp,flat,natural}/
SMuFL glyph IDs:
E262 = sharp
E260 = flat
E261 = natural
"""
import glob
import os
import re
import sys
import traceback
from io import BytesIO
import cairosvg
import numpy as np
from PIL import Image
# Force UTF-8 on stdout at import time (presumably so progress output never
# fails on a non-UTF-8 Windows console code page -- TODO confirm).
sys.stdout.reconfigure(encoding="utf-8")
# Hard-coded Windows paths: PrIMuS dataset input root and crop output root.
DATASET_ROOT = r"C:\src\accidentals\dataset"
OUTPUT_ROOT = r"C:\src\accidentals\crops"
# SMuFL codepoint (as captured from the Verovio <use> href) -> class label.
# The labels double as the output subdirectory names under OUTPUT_ROOT.
GLYPH_MAP = {"E262": "sharp", "E260": "flat", "E261": "natural"}
# Regex to find accidental <use> elements in Verovio SVG.
# Capture groups:
#   1 = element class ("keyAccid" or "accid")
#   2 = glyph id from the href ("E" followed by decimal digits only)
#   3, 4 = translate x, y (viewBox units)
#   5, 6 = scale x, y (captured but not used by extract_from_mei)
# NOTE(review): translate values must be non-negative integers to match;
# glyphs at negative or fractional coordinates would be skipped silently.
ACCID_RE = re.compile(
r'class="(keyAccid|accid)"[^>]*>\s*'
r'<use xlink:href="#(E\d+)[^"]*"\s+'
r'transform="translate\((\d+),\s*(\d+)\)\s*scale\(([^,]+),\s*([^)]+)\)"'
)
# Integer pixel width/height attributes; search() takes the first match in the SVG.
SVG_DIM_RE = re.compile(r'width="(\d+)px" height="(\d+)px"')
# Crop parameters (in SVG pixel units, viewBox/10).
# The crop window around a glyph's translate() origin (px, py) is
# [px - LEFT, px + RIGHT) x [py - TOP, py + BOTTOM), clamped to the image.
# The asymmetry presumably reflects the origin sitting at the glyph's left edge.
CROP_PAD_LEFT = 3
CROP_PAD_RIGHT = 22
CROP_PAD_TOP = 25
CROP_PAD_BOTTOM = 35
def setup_verovio():
    """Build a Verovio toolkit configured for compact, unpaginated rendering.

    Header, footer and system breaks are disabled, the scale is fixed at 100,
    page width/height adjust to the content, and all four page margins are
    zero. ``verovio`` is imported lazily, only when this function runs.

    Returns:
        The configured ``verovio.toolkit`` instance.
    """
    import verovio

    render_options = {
        "adjustPageHeight": True,
        "adjustPageWidth": True,
        "header": "none",
        "footer": "none",
        "breaks": "none",
        "scale": 100,
        # Zero margins, presumably so SVG coordinates need no margin offset.
        "pageMarginLeft": 0,
        "pageMarginRight": 0,
        "pageMarginTop": 0,
        "pageMarginBottom": 0,
    }
    toolkit = verovio.toolkit()
    toolkit.setOptions(render_options)
    return toolkit
def extract_from_mei(tk, mei_path: str) -> list[dict]:
    """Extract accidental crops from a single MEI file.

    Renders page 1 of the MEI with the given Verovio toolkit, finds accidental
    glyphs (``keyAccid`` / ``accid`` classes) in the SVG via ``ACCID_RE``,
    rasterizes the SVG with cairosvg at its declared pixel size, and crops a
    fixed-padding window (``CROP_PAD_*``) around each glyph's origin.

    Args:
        tk: A Verovio toolkit (e.g. from ``setup_verovio``). Its current
            options, including the font, control the rendering.
        mei_path: Path to a UTF-8 encoded MEI file.

    Returns:
        A list of dicts with keys ``'type'`` (a ``GLYPH_MAP`` label),
        ``'context'`` (``'keyAccid'`` or ``'accid'``) and ``'image'`` (a
        grayscale ``PIL.Image`` crop, dark glyph on a white background).
        The list is empty when Verovio fails to load or render the file, the
        SVG size cannot be read, no mapped accidental is found, or cairosvg
        fails to rasterize.

    Raises:
        Errors from reading the file or from the toolkit propagate to the
        caller (``main`` counts them as errors).
    """
    with open(mei_path, "r", encoding="utf-8") as f:
        mei_data = f.read()

    # loadData() reports failure through its boolean return value. Rendering
    # after a failed load could yield an empty or stale page, so stop here.
    if not tk.loadData(mei_data):
        return []
    svg = tk.renderToSVG(1)
    if not svg:
        return []

    # The first width/height pair in the SVG gives its rendered pixel size.
    dim_match = SVG_DIM_RE.search(svg)
    if not dim_match:
        return []
    svg_w = int(dim_match.group(1))
    svg_h = int(dim_match.group(2))
    if svg_w == 0 or svg_h == 0:
        return []

    # Locate mapped accidental glyphs and convert their origins to pixels.
    accid_positions = []
    for match in ACCID_RE.finditer(svg):
        glyph_type = GLYPH_MAP.get(match.group(2))
        if glyph_type is None:
            continue  # a glyph not listed in GLYPH_MAP
        tx, ty = int(match.group(3)), int(match.group(4))
        # viewBox coords -> pixel coords (viewBox is 10x pixel dimensions)
        accid_positions.append(
            {
                "type": glyph_type,
                "context": match.group(1),
                "px": tx / 10.0,
                "py": ty / 10.0,
            }
        )
    if not accid_positions:
        return []

    # Rasterize SVG to PNG. This is best-effort: a file cairosvg cannot render
    # simply yields no crops.
    try:
        png_bytes = cairosvg.svg2png(
            bytestring=svg.encode("utf-8"),
            output_width=svg_w,
            output_height=svg_h,
        )
    except Exception:
        return []

    # cairosvg draws on a transparent canvas, and PIL's RGBA -> "L" conversion
    # ignores alpha. Fully transparent background pixels (RGB 0,0,0) would turn
    # black and hide the black glyphs, so flatten onto opaque white first.
    # Already-opaque pixels are unaffected by the composite.
    rgba = Image.open(BytesIO(png_bytes)).convert("RGBA")
    white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    img = Image.alpha_composite(white, rgba).convert("L")
    img_w, img_h = img.size

    # Crop each accidental. Windows clamped to under 5 px in either
    # dimension are dropped.
    results = []
    for acc in accid_positions:
        x1 = max(0, int(acc["px"] - CROP_PAD_LEFT))
        y1 = max(0, int(acc["py"] - CROP_PAD_TOP))
        x2 = min(img_w, int(acc["px"] + CROP_PAD_RIGHT))
        y2 = min(img_h, int(acc["py"] + CROP_PAD_BOTTOM))
        if x2 - x1 < 5 or y2 - y1 < 5:
            continue
        results.append(
            {
                "type": acc["type"],
                "context": acc["context"],
                "image": img.crop((x1, y1, x2, y2)),
            }
        )
    return results
def main():
    """Extract accidental crops from every PrIMuS MEI file, once per font.

    Globs ``DATASET_ROOT/{package_aa,package_ab}/*/*.mei`` and renders each
    file with each of the Leipzig, Bravura and Gootville fonts. Every crop is
    saved as ``OUTPUT_ROOT/<label>/<sample_id>_f<font_idx>_<j>.png``.

    If extraction raises for a file, it is counted as an error and skipped;
    the first five tracebacks are printed. Progress is printed every 5000
    files of each pass, and a summary is printed at the end.
    """
    mei_files = sorted(
        path
        for pkg in ("package_aa", "package_ab")
        for path in glob.glob(os.path.join(DATASET_ROOT, pkg, "*", "*.mei"))
    )
    print(f"Found {len(mei_files)} MEI files")
    if not mei_files:
        # Most likely a misconfigured DATASET_ROOT; nothing to render.
        print(f"No MEI files under {DATASET_ROOT}; nothing to do.")
        return

    # One output directory per class label.
    for label in GLYPH_MAP.values():
        os.makedirs(os.path.join(OUTPUT_ROOT, label), exist_ok=True)

    tk = setup_verovio()
    # Process with multiple Verovio fonts for variety.
    fonts = ["Leipzig", "Bravura", "Gootville"]
    # Derive counters from GLYPH_MAP so a newly mapped glyph type cannot raise
    # KeyError at save time (that code runs outside the try below).
    counts = dict.fromkeys(GLYPH_MAP.values(), 0)
    errors = 0
    total_processed = 0  # successful (file, font) renders

    for font_idx, font in enumerate(fonts):
        tk.setOptions({"font": font})
        print(f"\n=== Font: {font} (pass {font_idx + 1}/{len(fonts)}) ===")
        for i, mei_path in enumerate(mei_files):
            if i % 5000 == 0:
                tallies = " ".join(f"{lbl}={n}" for lbl, n in counts.items())
                print(f" [{font}] {i}/{len(mei_files)} {tallies} errors={errors}")
            try:
                results = extract_from_mei(tk, mei_path)
            except Exception:
                errors += 1
                if errors <= 5:
                    traceback.print_exc()
                continue
            # PrIMuS stores each sample in a directory named after its id.
            sample_id = os.path.basename(os.path.dirname(mei_path))
            for j, res in enumerate(results):
                label = res["type"]
                fname = f"{sample_id}_f{font_idx}_{j}.png"
                res["image"].save(os.path.join(OUTPUT_ROOT, label, fname))
                counts[label] += 1
            total_processed += 1

    print("\n=== Done ===")
    print(
        f"Processed: {total_processed} successful renders of "
        f"{len(mei_files)} MEI files across {len(fonts)} fonts"
    )
    print(f"Errors: {errors}")
    print("Extracted crops:")
    for label, count in sorted(counts.items()):
        print(f" {label:10s}: {count:7d}")
    print(f"Output: {OUTPUT_ROOT}")


if __name__ == "__main__":
    main()