CNN trained on PrIMuS crops achieves 100% on held-out test set. Includes training pipeline, evaluation script, extraction tools, and saved model weights. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
161 lines
4.9 KiB
Python
161 lines
4.9 KiB
Python
"""Fast accidental extraction from PrIMuS using Verovio + cairosvg.
|
|
|
|
Optimizations:
- Pre-filter MEI files using agnostic encoding (skip files with no accidentals).
  NOTE: has_accidentals() implements this check, but main() currently skips the
  pre-filter and renders every MEI file.
- Single font pass (Leipzig) - augmentation can be done at training time
- Flush output for progress monitoring
|
|
"""
|
|
|
|
import glob
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
import traceback
|
|
from io import BytesIO
|
|
|
|
import cairosvg
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
# Emit UTF-8 on stdout. The call is guarded because sys.stdout may be None or
# a replacement stream without reconfigure() (e.g. io.StringIO), in which case
# the unguarded call would raise AttributeError at import time.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")

# Root directory searched for "package_*/*/*.mei" files by main().
DATASET_ROOT = r"C:\src\accidentals\dataset"
# Root directory for the crops; main() creates one sub-directory per label.
OUTPUT_ROOT = r"C:\src\accidentals\crops"

# Glyph id (as captured from xlink:href="#E...") -> crop label.
# NOTE(review): these look like SMuFL code points (U+E260 flat, U+E261 natural,
# U+E262 sharp) - confirm against the SMuFL glyph tables.
GLYPH_MAP = {"E262": "sharp", "E260": "flat", "E261": "natural"}

# Matches an element with class "keyAccid" or "accid" immediately followed by a
# <use> glyph reference. Groups:
#   1: class name, 2: glyph id (E + digits),
#   3/4: translate() x and y (unsigned integers only),
#   5/6: scale() factors (captured as raw text).
ACCID_RE = re.compile(
    r'class="(keyAccid|accid)"[^>]*>\s*'
    r'<use xlink:href="#(E\d+)[^"]*"\s+'
    r'transform="translate\((\d+),\s*(\d+)\)\s*scale\(([^,]+),\s*([^)]+)\)"'
)
# Captures integer pixel width (group 1) and height (group 2) from the first
# width="…px" height="…px" attribute pair found in the SVG text.
SVG_DIM_RE = re.compile(r'width="(\d+)px" height="(\d+)px"')
|
|
|
|
|
|
def has_accidentals(agnostic_path: str) -> bool:
    """Report whether an agnostic-encoding file mentions any accidental.

    Args:
        agnostic_path: Path of the agnostic file; it is read in full as UTF-8.

    Returns:
        True when the substring ``"accidental."`` occurs anywhere in the file
        contents, False otherwise.
    """
    with open(agnostic_path, "r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents.find("accidental.") != -1
|
|
|
|
|
|
def main():
    """Render every MEI file under DATASET_ROOT and save accidental crops.

    For each ``package_*/*/*.mei`` file the score is rendered to SVG with
    Verovio, accidental glyphs are located in the SVG text with ACCID_RE,
    the SVG is rasterized with cairosvg, and a fixed-size window around each
    accidental is written as a grayscale PNG to ``OUTPUT_ROOT/<label>/``.

    Progress is printed every 1000 files and a summary at the end. Any
    exception raised while handling one file is counted, reported (file path
    plus traceback for the first five only) and does not stop the run.
    Files that Verovio fails to load are counted as errors as well.
    """
    # Verovio is imported locally; nothing else in this module needs it.
    import verovio

    # Glob all MEI files at once (much faster than iterating dirs on Windows)
    print("Finding MEI files...", flush=True)
    mei_files = sorted(
        glob.glob(os.path.join(DATASET_ROOT, "package_*", "*", "*.mei"))
    )
    print(f"Found {len(mei_files)} MEI files (skipping pre-filter)", flush=True)

    # Create output directories
    for label in GLYPH_MAP.values():
        os.makedirs(os.path.join(OUTPUT_ROOT, label), exist_ok=True)

    # Setup Verovio: one unbroken system, no margins/header/footer, page size
    # adjusted to the content, Leipzig font.
    tk = verovio.toolkit()
    tk.setOptions({
        "adjustPageHeight": True,
        "adjustPageWidth": True,
        "header": "none",
        "footer": "none",
        "breaks": "none",
        "scale": 100,
        "pageMarginLeft": 0,
        "pageMarginRight": 0,
        "pageMarginTop": 0,
        "pageMarginBottom": 0,
        "font": "Leipzig",
    })

    counts = {"sharp": 0, "flat": 0, "natural": 0}
    errors = 0
    t0 = time.time()

    for i, mei_path in enumerate(mei_files):
        if i % 1000 == 0:
            elapsed = time.time() - t0
            rate = i / elapsed if elapsed > 0 else 0
            eta = (len(mei_files) - i) / rate if rate > 0 else 0
            print(
                f"[{i:6d}/{len(mei_files)}] "
                f"sharp={counts['sharp']} flat={counts['flat']} "
                f"natural={counts['natural']} "
                f"err={errors} "
                f"rate={rate:.0f}/s eta={eta/60:.0f}m",
                flush=True,
            )

        try:
            with open(mei_path, "r", encoding="utf-8") as f:
                mei_data = f.read()

            # loadData() reports failure through its return value; rendering
            # after a failed load must not be attributed to this file.
            if not tk.loadData(mei_data):
                raise ValueError(f"Verovio could not load {mei_path}")

            svg = tk.renderToSVG(1)
            if not svg:
                continue

            dim_match = SVG_DIM_RE.search(svg)
            if not dim_match:
                continue
            svg_w = int(dim_match.group(1))
            svg_h = int(dim_match.group(2))
            if svg_w == 0 or svg_h == 0:
                continue

            # Find accidentals
            accids = []
            for match in ACCID_RE.finditer(svg):
                glyph_type = GLYPH_MAP.get(match.group(2))
                if glyph_type:
                    # translate() values are divided by 10 to obtain raster
                    # pixel coordinates.
                    # NOTE(review): presumably Verovio's SVG units are 10x the
                    # px size at scale 100 - confirm.
                    accids.append({
                        "type": glyph_type,
                        "px": int(match.group(3)) / 10.0,
                        "py": int(match.group(4)) / 10.0,
                    })

            if not accids:
                continue

            # Rasterize at the pixel size declared by the SVG itself.
            png_bytes = cairosvg.svg2png(
                bytestring=svg.encode("utf-8"),
                output_width=svg_w,
                output_height=svg_h,
            )
            # alpha_composite() requires both images to be RGBA, so normalise
            # the decoded PNG before flattening it onto a white background.
            img_rgba = Image.open(BytesIO(png_bytes)).convert("RGBA")
            white_bg = Image.new("RGBA", img_rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(white_bg, img_rgba).convert("L")

            # Crops are named after the directory that contains the MEI file.
            sample_id = os.path.basename(os.path.dirname(mei_path))
            for j, acc in enumerate(accids):
                # Fixed window around the glyph anchor: 3 px left / 22 px
                # right, 25 px above / 35 px below, clamped to the page.
                x1 = max(0, int(acc["px"] - 3))
                y1 = max(0, int(acc["py"] - 25))
                x2 = min(svg_w, int(acc["px"] + 22))
                y2 = min(svg_h, int(acc["py"] + 35))
                # Skip windows that clamping reduced to under 5 px a side.
                if x2 - x1 < 5 or y2 - y1 < 5:
                    continue

                crop = img.crop((x1, y1, x2, y2))
                fname = f"{sample_id}_{j}.png"
                crop.save(os.path.join(OUTPUT_ROOT, acc["type"], fname))
                counts[acc["type"]] += 1

        except Exception:
            # Best-effort batch job: count the failure and keep going. Only
            # the first five failures are reported in detail.
            errors += 1
            if errors <= 5:
                print(f"Error while processing {mei_path}:",
                      file=sys.stderr, flush=True)
                traceback.print_exc()

    elapsed = time.time() - t0
    print(f"\n=== Done in {elapsed/60:.1f} minutes ===", flush=True)
    print(f"Errors: {errors}")
    print("Crops extracted:")
    for label, count in sorted(counts.items()):
        print(f" {label:10s}: {count:7d}")
    print(f"Total: {sum(counts.values())}")
    print(f"Output: {OUTPUT_ROOT}")
|
|
|
|
|
|
# Run the extraction only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
|