# Part of a project where a CNN trained on PrIMuS crops achieves 100% on a
# held-out test set; the project includes a training pipeline, evaluation
# script, extraction tools, and saved model weights.
"""Test segmenting symbols from a PrIMuS image using connected components."""
import cv2
import numpy as np
from PIL import Image


def segment_symbols(png_path: str):
    """Segment a PrIMuS image into individual symbols.

    Loads the image as grayscale, binarizes it, removes staff lines with a
    wide horizontal morphological opening, extracts connected components,
    and merges horizontally-adjacent components into symbol clusters.

    Args:
        png_path: Path to the PrIMuS PNG image.

    Returns:
        Tuple ``(clusters, img, no_lines)`` where ``clusters`` is a list of
        dicts with keys ``x1``/``y1``/``x2``/``y2``/``components``, ``img``
        is the grayscale image array, and ``no_lines`` is the binarized
        image with staff lines removed.
    """
    # Load as grayscale
    img = cv2.imread(png_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # PrIMuS uses palette-mode PNGs that cv2.imread can fail on;
        # fall back to PIL and convert to 8-bit grayscale ("L").
        pil_img = Image.open(png_path).convert("L")
        img = np.array(pil_img)

    h, w = img.shape
    print(f"Image size: {w} x {h}")

    # Binarize with inversion: ink becomes white (255) foreground on black,
    # which is what morphology and connected components operate on.
    _, binary = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY_INV)

    # Detect staff lines with a wide, 1-px-tall opening: only long
    # horizontal runs survive. max(1, ...) guards against a zero-width
    # kernel on images narrower than 4 px.
    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (max(1, w // 4), 1))
    staff_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horizontal_kernel)

    # Remove staff lines from the binary image. cv2.subtract saturates at 0,
    # so there is no uint8 wraparound even if the images ever disagree.
    no_lines = cv2.subtract(binary, staff_lines)

    # Small morphological close to reconnect symbol fragments broken by the
    # staff-line removal.
    close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    no_lines = cv2.morphologyEx(no_lines, cv2.MORPH_CLOSE, close_kernel)

    # Connected components with per-label stats (bbox + area) and centroids.
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        no_lines, connectivity=8
    )

    # Collect components, filtering out tiny noise (area < 10 pixels).
    symbols = []
    for i in range(1, num_labels):  # label 0 is the background
        x, y, sw, sh, area = stats[i]
        if area < 10:
            continue
        symbols.append({
            "x": x, "y": y, "w": sw, "h": sh,
            "area": area, "cx": centroids[i][0],
        })

    # Sort by x position (left to right) so clustering can scan once.
    symbols.sort(key=lambda s: s["x"])

    # Group nearby components into symbol clusters: a single symbol (e.g. a
    # sharp) may have multiple disconnected parts. A component starting
    # within 5 px of the previous cluster's right edge joins that cluster.
    clusters = []
    for s in symbols:
        if clusters and s["x"] < clusters[-1]["x2"] + 5:
            # Merge into previous cluster, growing its bounding box.
            c = clusters[-1]
            c["x1"] = min(c["x1"], s["x"])
            c["y1"] = min(c["y1"], s["y"])
            c["x2"] = max(c["x2"], s["x"] + s["w"])
            c["y2"] = max(c["y2"], s["y"] + s["h"])
            c["components"].append(s)
        else:
            clusters.append({
                "x1": s["x"], "y1": s["y"],
                "x2": s["x"] + s["w"], "y2": s["y"] + s["h"],
                "components": [s],
            })

    print(f"Found {len(symbols)} components -> {len(clusters)} symbol clusters")
    for i, c in enumerate(clusters):
        w = c["x2"] - c["x1"]
        h = c["y2"] - c["y1"]
        print(f" Cluster {i:3d}: x={c['x1']:5d}-{c['x2']:5d} y={c['y1']:3d}-{c['y2']:3d} size={w:3d}x{h:3d}")

    return clusters, img, no_lines
def match_tokens_to_clusters(agnostic_path: str, clusters: list):
    """Match agnostic tokens to segmented clusters.

    Currently only loads and prints the tokens alongside the cluster count
    so the two can be compared by eye; no actual alignment is performed yet.

    Args:
        agnostic_path: Path to a PrIMuS ``.agnostic`` file (one line of
            tab-separated tokens).
        clusters: Symbol clusters as returned by ``segment_symbols``.

    Returns:
        The list of agnostic tokens read from the file.
    """
    # Explicit UTF-8: the default locale codec (e.g. cp1252 on Windows)
    # could mangle non-ASCII token text.
    with open(agnostic_path, "r", encoding="utf-8") as f:
        tokens = f.read().strip().split("\t")

    print(f"\nAgnostic tokens ({len(tokens)}):")
    for i, tok in enumerate(tokens):
        print(f" {i:3d}: {tok}")

    print(f"\nClusters: {len(clusters)}")
    print(f"Tokens: {len(tokens)}")

    return tokens
if __name__ == "__main__":
    # Local import kept inside the script guard; hoisted to the top of the
    # block rather than sitting mid-flow as before.
    import os

    # Hard-coded sample from the PrIMuS dataset package.
    sample_dir = r"C:\src\accidentals\dataset\package_aa\000141058-1_1_1"
    sample_id = "000141058-1_1_1"

    png_path = f"{sample_dir}/{sample_id}.png"
    agnostic_path = f"{sample_dir}/{sample_id}.agnostic"

    # Segment the image, then print tokens so cluster/token counts can be
    # compared by eye.
    clusters, img, no_lines = segment_symbols(png_path)
    tokens = match_tokens_to_clusters(agnostic_path, clusters)

    # Save debug images for visual inspection.
    debug_dir = r"C:\src\accidentals\debug"
    os.makedirs(debug_dir, exist_ok=True)

    # Invert back to black-ink-on-white for easier viewing.
    cv2.imwrite(f"{debug_dir}/no_lines.png", 255 - no_lines)

    # Draw numbered bounding boxes for each cluster on the original image.
    vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    for i, c in enumerate(clusters):
        color = (0, 0, 255)  # red in BGR
        cv2.rectangle(vis, (c["x1"], c["y1"]), (c["x2"], c["y2"]), color, 1)
        cv2.putText(vis, str(i), (c["x1"], c["y1"] - 2),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)
    cv2.imwrite(f"{debug_dir}/segmented.png", vis)
    print(f"\nDebug images saved to {debug_dir}/")
|