CNN trained on PrIMuS crops achieves 100% on held-out test set. Includes training pipeline, evaluation script, extraction tools, and saved model weights. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
214 lines
5.7 KiB
Python
214 lines
5.7 KiB
Python
"""Extract accidental crops from the PrIMuS dataset using Verovio re-rendering.
|
|
|
|
For each MEI file:
|
|
1. Render with Verovio to SVG (getting exact symbol positions)
|
|
2. Rasterize SVG to PNG with cairosvg
|
|
3. Parse SVG for accidental elements (sharp/flat/natural)
|
|
4. Crop each accidental from the rasterized image
|
|
5. Save organized by type: crops/{sharp,flat,natural}/
|
|
|
|
SMuFL glyph IDs:
|
|
E262 = sharp
|
|
E260 = flat
|
|
E261 = natural
|
|
"""
|
|
|
|
import glob
|
|
import os
|
|
import re
|
|
import sys
|
|
import traceback
|
|
from io import BytesIO
|
|
|
|
import cairosvg
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
# Force UTF-8 console output, presumably so printing never fails with
# UnicodeEncodeError on a non-UTF-8 console. Guarded because replacement
# streams (some IDE consoles / wrappers) are not io.TextIOWrapper and then
# have no reconfigure() method.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")

# Input (folder holding the PrIMuS package_* directories) and output (crop
# tree) locations. The defaults are the original machine-specific Windows
# paths; set the environment variables to run elsewhere without editing code.
DATASET_ROOT = os.environ.get(
    "ACCIDENTALS_DATASET_ROOT", r"C:\src\accidentals\dataset"
)
OUTPUT_ROOT = os.environ.get(
    "ACCIDENTALS_OUTPUT_ROOT", r"C:\src\accidentals\crops"
)

# SMuFL glyph id (as referenced by Verovio's <use xlink:href="#E26x-...">)
# -> class label / output sub-directory name.
GLYPH_MAP = {"E262": "sharp", "E260": "flat", "E261": "natural"}

# Regex to find accidental <use> elements in Verovio SVG. Matches an element
# with class "keyAccid" (key signature) or "accid" (note accidental) whose
# next child is a <use> glyph reference, capturing:
#   1: class, 2: glyph id (e.g. "E262"),
#   3/4: translate x/y (viewBox units), 5/6: scale x/y.
# NOTE(review): tied to the "translate(x, y) scale(sx, sy)" serialization and
# to non-negative integer coordinates; elements in any other form are
# silently skipped. Confirm against the installed Verovio version.
ACCID_RE = re.compile(
    r'class="(keyAccid|accid)"[^>]*>\s*'
    r'<use xlink:href="#(E\d+)[^"]*"\s+'
    r'transform="translate\((\d+),\s*(\d+)\)\s*scale\(([^,]+),\s*([^)]+)\)"'
)
# Pixel size of the outer <svg> element, e.g. width="2100px" height="300px".
SVG_DIM_RE = re.compile(r'width="(\d+)px" height="(\d+)px"')

# Crop window around an accidental's placement point (its translate()
# position), in SVG pixel units (viewBox units / 10).
CROP_PAD_LEFT = 3
CROP_PAD_RIGHT = 22
CROP_PAD_TOP = 25
CROP_PAD_BOTTOM = 35
|
|
|
|
|
|
def setup_verovio():
    """Create and return a Verovio toolkit configured for crop extraction.

    Options applied: page auto-sizing in both directions
    ("adjustPageHeight"/"adjustPageWidth"), no header or footer,
    "breaks" set to "none", scale 100 and all four page margins at 0.
    The music font is not set here; main() switches it per pass.
    """
    # Deferred import: the verovio package is only needed when this runs.
    import verovio

    render_options = {
        "adjustPageHeight": True,
        "adjustPageWidth": True,
        "header": "none",
        "footer": "none",
        "breaks": "none",
        "scale": 100,
        "pageMarginLeft": 0,
        "pageMarginRight": 0,
        "pageMarginTop": 0,
        "pageMarginBottom": 0,
    }
    toolkit = verovio.toolkit()
    toolkit.setOptions(render_options)
    return toolkit
|
|
|
|
|
|
def extract_from_mei(tk, mei_path: str) -> list[dict]:
    """Extract accidental crops from a single MEI file.

    Renders the file with *tk* (using whatever options, including the font,
    are currently set on it), locates accidental glyphs in the SVG,
    rasterizes the SVG with cairosvg and cuts a padded window around each
    accidental.

    Args:
        tk: A Verovio toolkit, e.g. from ``setup_verovio()``.
        mei_path: Path to a UTF-8 encoded MEI file.

    Returns:
        A list of dicts with keys ``'type'`` (``'sharp'``, ``'flat'`` or
        ``'natural'``), ``'context'`` (``'keyAccid'`` or ``'accid'``) and
        ``'image'`` (grayscale PIL Image, transparent areas flattened to
        white). The list is empty when Verovio cannot load or render the
        file, the SVG size is missing or zero, no known accidentals are
        found, or cairosvg fails to rasterize.

    Raises:
        OSError, UnicodeDecodeError: If the file cannot be read. Exceptions
            raised by Verovio are not caught here either.
    """
    with open(mei_path, "r", encoding="utf-8") as f:
        mei_data = f.read()

    # loadData reports failure via its return value; skip rendering rather
    # than risk rendering stale or partial toolkit state.
    if not tk.loadData(mei_data):
        return []
    svg = tk.renderToSVG(1)  # first page
    if not svg:
        return []

    # Outer SVG size in pixels; also requested as the raster size below.
    dim_match = SVG_DIM_RE.search(svg)
    if not dim_match:
        return []
    svg_w = int(dim_match.group(1))
    svg_h = int(dim_match.group(2))
    if svg_w == 0 or svg_h == 0:
        return []

    # Locate accidentals first so files without any skip rasterization.
    accid_positions = []
    for match in ACCID_RE.finditer(svg):
        glyph_type = GLYPH_MAP.get(match.group(2))
        if glyph_type is None:
            continue  # glyph is not one of the GLYPH_MAP classes
        # translate() is in viewBox units, which are 10x the pixel units.
        accid_positions.append(
            {
                "type": glyph_type,
                "context": match.group(1),
                "px": int(match.group(3)) / 10.0,
                "py": int(match.group(4)) / 10.0,
            }
        )
    if not accid_positions:
        return []

    try:
        png_bytes = cairosvg.svg2png(
            bytestring=svg.encode("utf-8"),
            output_width=svg_w,
            output_height=svg_h,
        )
    except Exception:
        # Deliberate best effort: a file cairosvg cannot rasterize yields
        # no crops instead of aborting the caller.
        return []

    # No background_color is passed, so cairosvg leaves unpainted pixels
    # fully transparent (typically RGB 0, i.e. black). PIL's RGBA -> "L"
    # conversion ignores alpha, so the background would come out black --
    # the same as the glyphs. Composite onto opaque white first; already
    # opaque pixels are unaffected.
    with Image.open(BytesIO(png_bytes)) as raster:
        rgba = raster.convert("RGBA")
    white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    img = Image.alpha_composite(white, rgba).convert("L")
    img_w, img_h = img.size

    results = []
    for acc in accid_positions:
        # Padded window around the placement point, clamped to the image.
        x1 = max(0, int(acc["px"] - CROP_PAD_LEFT))
        y1 = max(0, int(acc["py"] - CROP_PAD_TOP))
        x2 = min(img_w, int(acc["px"] + CROP_PAD_RIGHT))
        y2 = min(img_h, int(acc["py"] + CROP_PAD_BOTTOM))
        # Skip windows that clamping left nearly empty.
        if x2 - x1 < 5 or y2 - y1 < 5:
            continue
        results.append(
            {
                "type": acc["type"],
                "context": acc["context"],
                "image": img.crop((x1, y1, x2, y2)),
            }
        )
    return results
|
|
|
|
|
|
def main():
    """Extract accidental crops from all PrIMuS MEI files, once per font.

    Collects the ``*.mei`` files one directory level below
    DATASET_ROOT/package_aa and DATASET_ROOT/package_ab, renders each with
    every font in ``fonts``, and saves each crop as
    OUTPUT_ROOT/<type>/<sample_id>_f<font_idx>_<j>.png. Here sample_id is
    the MEI file's parent directory name and j numbers the crops of that
    file within one font pass. Exceptions from extract_from_mei are
    counted; only the first five tracebacks are printed.
    """
    mei_files = sorted(
        path
        for pkg in ("package_aa", "package_ab")
        for path in glob.glob(os.path.join(DATASET_ROOT, pkg, "*", "*.mei"))
    )
    print(f"Found {len(mei_files)} MEI files")

    for label in GLYPH_MAP.values():
        os.makedirs(os.path.join(OUTPUT_ROOT, label), exist_ok=True)

    tk = setup_verovio()

    # Process with multiple Verovio fonts for variety.
    fonts = ["Leipzig", "Bravura", "Gootville"]
    # One counter per output class; derived from GLYPH_MAP so the two
    # cannot drift apart.
    counts = {label: 0 for label in GLYPH_MAP.values()}
    errors = 0
    renders = 0  # successful (file, font) passes -- NOT distinct files

    for font_idx, font in enumerate(fonts):
        tk.setOptions({"font": font})
        print(f"\n=== Font: {font} (pass {font_idx+1}/{len(fonts)}) ===")

        for i, mei_path in enumerate(mei_files):
            if i % 5000 == 0:
                print(
                    f" [{font}] {i}/{len(mei_files)} "
                    f"sharp={counts['sharp']} flat={counts['flat']} "
                    f"natural={counts['natural']} errors={errors}"
                )

            try:
                results = extract_from_mei(tk, mei_path)
            except Exception:
                errors += 1
                if errors <= 5:
                    traceback.print_exc()
                continue

            sample_id = os.path.basename(os.path.dirname(mei_path))
            for j, res in enumerate(results):
                label = res["type"]
                fname = f"{sample_id}_f{font_idx}_{j}.png"
                res["image"].save(os.path.join(OUTPUT_ROOT, label, fname))
                counts[label] += 1

            renders += 1

    print("\n=== Done ===")
    # Report renders, not files: every file is processed once per font.
    print(
        f"Processed: {renders} MEI renders "
        f"({len(mei_files)} files x {len(fonts)} fonts)"
    )
    print(f"Errors: {errors}")
    print("Extracted crops:")
    for label, count in sorted(counts.items()):
        print(f" {label:10s}: {count:7d}")
    print(f"Output: {OUTPUT_ROOT}")


if __name__ == "__main__":
    main()
|