Initial commit: accidental classifier (sharp/flat/natural)
CNN trained on PrIMuS crops achieves 100% on held-out test set. Includes training pipeline, evaluation script, extraction tools, and saved model weights. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
a7164caf08
commit
afc16c2bbb
8 changed files with 1069 additions and 0 deletions
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
__pycache__/
|
||||||
|
.claude/
|
||||||
|
crops/
|
||||||
|
dataset/
|
||||||
|
debug/
|
||||||
|
*.tgz
|
||||||
|
nul
|
||||||
148
eval.py
Normal file
148
eval.py
Normal file
|
|
@ -0,0 +1,148 @@
|
||||||
|
"""Evaluate the saved accidental classifier on the held-out test set.
|
||||||
|
|
||||||
|
Uses the same seed (42), split ratios, and data loading as train.py
|
||||||
|
so the test set is identical to what training used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from train import (
|
||||||
|
AccidentalCNN,
|
||||||
|
CROPS_ROOT,
|
||||||
|
OUTPUT_DIR,
|
||||||
|
LABELS,
|
||||||
|
IMG_SIZE,
|
||||||
|
SEED,
|
||||||
|
VAL_SPLIT,
|
||||||
|
TEST_SPLIT,
|
||||||
|
MAX_PER_CLASS,
|
||||||
|
InMemoryDataset,
|
||||||
|
load_all_into_ram,
|
||||||
|
split_data,
|
||||||
|
)
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
MODEL_PATH = OUTPUT_DIR / "accidental_classifier.pt"
|
||||||
|
|
||||||
|
|
||||||
|
def collect_file_paths():
    """Collect file paths in the same order as load_all_into_ram.

    Must be called after seeding so random.shuffle matches train.py.
    """
    all_paths = []
    for class_dir in (CROPS_ROOT / name for name in LABELS):
        # Mirror load_all_into_ram exactly: list, shuffle, then cap.
        png_names = [entry for entry in os.listdir(class_dir) if entry.endswith(".png")]
        random.shuffle(png_names)
        if MAX_PER_CLASS and len(png_names) > MAX_PER_CLASS:
            png_names = png_names[:MAX_PER_CLASS]
        all_paths.extend(str(class_dir / entry) for entry in png_names)
    return all_paths
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Evaluate the saved accidental classifier on the held-out test set.

    Re-creates train.py's exact RNG state so split_data() produces the
    identical test indices that training held out, then reports overall
    accuracy, per-class accuracy, a confusion matrix, and sample
    misclassified file paths.
    """
    # Reproduce the exact same RNG state and splits as train.py
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Collect file paths (consumes same random state as load_all_into_ram)
    file_paths = collect_file_paths()

    # Re-seed and reload to get tensors (load_all_into_ram shuffles again internally)
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    images, labels = load_all_into_ram()
    # Only the test indices matter here; train/val are discarded.
    _, _, test_idx = split_data(images, labels)

    print(f"\nTest set size: {len(test_idx)}")
    for i, name in enumerate(LABELS):
        n = (labels[test_idx] == i).sum().item()
        print(f" {name:8s}: {n}")

    # Load model
    device = torch.device("cpu")
    model = AccidentalCNN().to(device)
    # weights_only=True avoids unpickling arbitrary objects from the checkpoint.
    model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))
    model.eval()
    print(f"\nLoaded weights from {MODEL_PATH}")

    # Run inference
    test_ds = InMemoryDataset(images[test_idx], labels[test_idx], augment=False)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=0)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for imgs, lbls in test_loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            # argmax over class logits
            _, predicted = outputs.max(1)
            all_preds.append(predicted.cpu())
            all_labels.append(lbls)

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Overall accuracy
    correct = (all_preds == all_labels).sum().item()
    total = len(all_labels)
    print(f"\n=== Test Results ===")
    print(f"Overall accuracy: {correct}/{total} ({correct/total:.1%})")

    # Per-class accuracy
    print(f"\nPer-class accuracy:")
    for i, name in enumerate(LABELS):
        mask = all_labels == i
        cls_total = mask.sum().item()
        cls_correct = (all_preds[mask] == i).sum().item()
        acc = cls_correct / cls_total if cls_total > 0 else 0
        print(f" {name:8s}: {cls_correct}/{cls_total} ({acc:.1%})")

    # Confusion matrix
    n_classes = len(LABELS)
    cm = torch.zeros(n_classes, n_classes, dtype=torch.long)
    for t, p in zip(all_labels, all_preds):
        cm[t][p] += 1

    print(f"\nConfusion matrix (rows=true, cols=predicted):")
    header = " " + "".join(f"{name:>9s}" for name in LABELS)
    print(header)
    for i, name in enumerate(LABELS):
        row = f" {name:8s}" + "".join(f"{cm[i][j].item():9d}" for j in range(n_classes))
        print(row)

    # Sample misclassifications
    # test_idx indexes into the full dataset, which is ordered the same way
    # as file_paths (both derive from the same shuffled listing).
    test_indices = test_idx.tolist()
    misclassified = []
    for i in range(len(all_preds)):
        if all_preds[i] != all_labels[i]:
            orig_idx = test_indices[i]
            misclassified.append((
                file_paths[orig_idx],
                LABELS[all_labels[i].item()],
                LABELS[all_preds[i].item()],
            ))

    print(f"\nMisclassifications: {len(misclassified)}/{total}")
    if misclassified:
        print(f"\nSample misclassifications (up to 10):")
        for path, true_lbl, pred_lbl in misclassified[:10]:
            print(f" {path}")
            print(f" true={true_lbl} predicted={pred_lbl}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
102
explore_dataset.py
Normal file
102
explore_dataset.py
Normal file
|
|
@ -0,0 +1,102 @@
|
||||||
|
"""Explore the PrIMuS dataset to understand accidental distribution and image structure."""
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
from collections import Counter
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
DATASET_ROOT = r"C:\src\accidentals\dataset"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_agnostic(path: str) -> list[str]:
    """Read an agnostic encoding file and return its tab-separated tokens."""
    with open(path, "r", encoding="utf-8") as handle:
        raw = handle.read()
    return raw.strip().split("\t")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Report accidental/token statistics and basic image stats for PrIMuS."""
    # Find all agnostic files
    patterns = [
        os.path.join(DATASET_ROOT, "package_aa", "*", "*.agnostic"),
        os.path.join(DATASET_ROOT, "package_ab", "*", "*.agnostic"),
    ]
    agnostic_files = []
    for p in patterns:
        agnostic_files.extend(glob.glob(p))

    print(f"Total incipits: {len(agnostic_files)}")

    # Count accidental tokens
    accidental_type_counts = Counter()  # sharp/flat/natural
    accidental_full_counts = Counter()  # full token like accidental.sharp-L5
    incipits_with_accidentals = 0
    incipits_with_inline_accidentals = 0  # accidentals that aren't in key sig
    all_symbol_types = Counter()
    total_accidentals = 0

    for path in agnostic_files:
        tokens = parse_agnostic(path)
        has_any_accidental = False
        has_inline = False
        past_time_sig = False

        for tok in tokens:
            # Track symbol types (just the prefix)
            base = tok.split("-")[0] if "-" in tok else tok
            all_symbol_types[base] += 1

            # Heuristic: digit tokens form the time signature; any accidental
            # seen after one is "inline" rather than key-signature.
            if tok.startswith("digit."):
                past_time_sig = True

            if tok.startswith("accidental."):
                has_any_accidental = True
                total_accidentals += 1
                # Extract type: accidental.sharp, accidental.flat, etc.
                acc_type = tok.split("-")[0]  # e.g. "accidental.sharp"
                accidental_type_counts[acc_type] += 1
                accidental_full_counts[tok] += 1

                if past_time_sig:
                    has_inline = True

        if has_any_accidental:
            incipits_with_accidentals += 1
        if has_inline:
            incipits_with_inline_accidentals += 1

    print(f"\n=== Accidental Statistics ===")
    print(f"Total accidental tokens: {total_accidentals}")
    print(f"Incipits with any accidentals: {incipits_with_accidentals} / {len(agnostic_files)} ({100*incipits_with_accidentals/len(agnostic_files):.1f}%)")
    print(f"Incipits with inline accidentals: {incipits_with_inline_accidentals} / {len(agnostic_files)} ({100*incipits_with_inline_accidentals/len(agnostic_files):.1f}%)")

    print(f"\n=== Accidental Type Counts ===")
    for acc_type, count in accidental_type_counts.most_common():
        print(f" {acc_type:25s} {count:7d}")

    print(f"\n=== Top 20 Accidental Positions ===")
    for tok, count in accidental_full_counts.most_common(20):
        print(f" {tok:30s} {count:7d}")

    print(f"\n=== Top 30 Symbol Types ===")
    for sym, count in all_symbol_types.most_common(30):
        print(f" {sym:30s} {count:7d}")

    # Image statistics from a sample
    print(f"\n=== Image Statistics (sample of 500) ===")
    png_files = glob.glob(os.path.join(DATASET_ROOT, "package_aa", "*", "*.png"))[:500]
    widths, heights = [], []
    for f in png_files:
        im = Image.open(f)
        widths.append(im.size[0])
        heights.append(im.size[1])

    widths = np.array(widths)
    heights = np.array(heights)
    print(f" Width: min={widths.min()}, max={widths.max()}, mean={widths.mean():.0f}, std={widths.std():.0f}")
    print(f" Height: min={heights.min()}, max={heights.max()}, mean={heights.mean():.0f}, std={heights.std():.0f}")
    # Only sample 50 files for the mode histogram to keep this quick.
    print(f" Modes: {Counter(Image.open(f).mode for f in png_files[:50])}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
214
extract_accidentals.py
Normal file
214
extract_accidentals.py
Normal file
|
|
@ -0,0 +1,214 @@
|
||||||
|
"""Extract accidental crops from the PrIMuS dataset using Verovio re-rendering.
|
||||||
|
|
||||||
|
For each MEI file:
|
||||||
|
1. Render with Verovio to SVG (getting exact symbol positions)
|
||||||
|
2. Rasterize SVG to PNG with cairosvg
|
||||||
|
3. Parse SVG for accidental elements (sharp/flat/natural)
|
||||||
|
4. Crop each accidental from the rasterized image
|
||||||
|
5. Save organized by type: crops/{sharp,flat,natural}/
|
||||||
|
|
||||||
|
SMuFL glyph IDs:
|
||||||
|
E262 = sharp
|
||||||
|
E260 = flat
|
||||||
|
E261 = natural
|
||||||
|
"""
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import cairosvg
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
sys.stdout.reconfigure(encoding="utf-8")
|
||||||
|
|
||||||
|
DATASET_ROOT = r"C:\src\accidentals\dataset"
|
||||||
|
OUTPUT_ROOT = r"C:\src\accidentals\crops"
|
||||||
|
|
||||||
|
GLYPH_MAP = {"E262": "sharp", "E260": "flat", "E261": "natural"}
|
||||||
|
|
||||||
|
# Regex to find accidental <use> elements in Verovio SVG
|
||||||
|
ACCID_RE = re.compile(
|
||||||
|
r'class="(keyAccid|accid)"[^>]*>\s*'
|
||||||
|
r'<use xlink:href="#(E\d+)[^"]*"\s+'
|
||||||
|
r'transform="translate\((\d+),\s*(\d+)\)\s*scale\(([^,]+),\s*([^)]+)\)"'
|
||||||
|
)
|
||||||
|
SVG_DIM_RE = re.compile(r'width="(\d+)px" height="(\d+)px"')
|
||||||
|
|
||||||
|
# Crop parameters (in SVG pixel units, viewBox/10)
|
||||||
|
CROP_PAD_LEFT = 3
|
||||||
|
CROP_PAD_RIGHT = 22
|
||||||
|
CROP_PAD_TOP = 25
|
||||||
|
CROP_PAD_BOTTOM = 35
|
||||||
|
|
||||||
|
|
||||||
|
def setup_verovio():
    """Create a Verovio toolkit configured for tight, single-system renders.

    Margins are zeroed and header/footer/breaks suppressed so the rendered
    SVG hugs the music, keeping glyph translate() coordinates directly
    usable for cropping.
    """
    # Imported lazily so the module can be imported without verovio installed.
    import verovio

    tk = verovio.toolkit()
    tk.setOptions(
        {
            "adjustPageHeight": True,
            "adjustPageWidth": True,
            "header": "none",
            "footer": "none",
            "breaks": "none",
            "scale": 100,
            "pageMarginLeft": 0,
            "pageMarginRight": 0,
            "pageMarginTop": 0,
            "pageMarginBottom": 0,
        }
    )
    return tk
|
||||||
|
|
||||||
|
|
||||||
|
def extract_from_mei(tk, mei_path: str) -> list[dict]:
    """Extract accidental crops from a single MEI file.

    Renders the MEI with Verovio, finds accidental glyphs in the SVG via
    ACCID_RE, rasterizes the SVG with cairosvg, then crops a fixed window
    around each glyph anchor.

    Returns list of dicts with 'type', 'image' (PIL Image), 'context' (keyAccid/accid).
    """
    with open(mei_path, "r", encoding="utf-8") as f:
        mei_data = f.read()

    tk.loadData(mei_data)
    svg = tk.renderToSVG(1)

    if not svg:
        return []

    # Get SVG pixel dimensions
    dim_match = SVG_DIM_RE.search(svg)
    if not dim_match:
        return []
    svg_w = int(dim_match.group(1))
    svg_h = int(dim_match.group(2))

    if svg_w == 0 or svg_h == 0:
        return []

    # Find accidentals in SVG
    accid_positions = []
    for match in ACCID_RE.finditer(svg):
        cls = match.group(1)  # "keyAccid" or "accid"
        glyph_id = match.group(2)
        tx, ty = int(match.group(3)), int(match.group(4))

        glyph_type = GLYPH_MAP.get(glyph_id)
        if glyph_type is None:
            # Not a sharp/flat/natural glyph we care about.
            continue

        # viewBox coords -> pixel coords (viewBox is 10x pixel dimensions)
        px = tx / 10.0
        py = ty / 10.0

        accid_positions.append(
            {"type": glyph_type, "context": cls, "px": px, "py": py}
        )

    if not accid_positions:
        return []

    # Rasterize SVG to PNG
    try:
        png_bytes = cairosvg.svg2png(
            bytestring=svg.encode("utf-8"),
            output_width=svg_w,
            output_height=svg_h,
        )
    except Exception:
        # cairosvg occasionally fails on odd SVGs; skip the whole file.
        return []

    img = Image.open(BytesIO(png_bytes)).convert("L")

    # Crop each accidental
    results = []
    for acc in accid_positions:
        # Fixed padding window around the glyph anchor, clamped to bounds.
        x1 = max(0, int(acc["px"] - CROP_PAD_LEFT))
        y1 = max(0, int(acc["py"] - CROP_PAD_TOP))
        x2 = min(svg_w, int(acc["px"] + CROP_PAD_RIGHT))
        y2 = min(svg_h, int(acc["py"] + CROP_PAD_BOTTOM))

        # Skip degenerate crops at the image edge.
        if x2 - x1 < 5 or y2 - y1 < 5:
            continue

        crop = img.crop((x1, y1, x2, y2))
        results.append(
            {
                "type": acc["type"],
                "context": acc["context"],
                "image": crop,
            }
        )

    return results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Render every MEI file with several music fonts and save accidental crops.

    Crops are written to OUTPUT_ROOT/{sharp,flat,natural}/ with filenames
    encoding the sample id, the font pass, and the crop index.
    """
    # Find all MEI files
    mei_files = []
    for pkg in ["package_aa", "package_ab"]:
        mei_files.extend(
            glob.glob(os.path.join(DATASET_ROOT, pkg, "*", "*.mei"))
        )
    mei_files.sort()

    print(f"Found {len(mei_files)} MEI files")

    # Create output directories
    for label in GLYPH_MAP.values():
        os.makedirs(os.path.join(OUTPUT_ROOT, label), exist_ok=True)

    tk = setup_verovio()

    # Process with multiple Verovio fonts for variety
    fonts = ["Leipzig", "Bravura", "Gootville"]
    counts = {"sharp": 0, "flat": 0, "natural": 0}
    errors = 0
    total_processed = 0

    for font_idx, font in enumerate(fonts):
        tk.setOptions({"font": font})
        print(f"\n=== Font: {font} (pass {font_idx+1}/{len(fonts)}) ===")

        for i, mei_path in enumerate(mei_files):
            # Periodic progress line
            if i % 5000 == 0:
                print(
                    f" [{font}] {i}/{len(mei_files)} "
                    f"sharp={counts['sharp']} flat={counts['flat']} "
                    f"natural={counts['natural']} errors={errors}"
                )

            try:
                results = extract_from_mei(tk, mei_path)
            except Exception:
                # Keep going on per-file failures; show the first few tracebacks.
                errors += 1
                if errors <= 5:
                    traceback.print_exc()
                continue

            # PrIMuS stores each incipit in a directory named by its sample id.
            sample_id = os.path.basename(os.path.dirname(mei_path))
            for j, res in enumerate(results):
                label = res["type"]
                fname = f"{sample_id}_f{font_idx}_{j}.png"
                out_path = os.path.join(OUTPUT_ROOT, label, fname)
                res["image"].save(out_path)
                counts[label] += 1

            total_processed += 1

    print(f"\n=== Done ===")
    print(f"Processed: {total_processed} MEI files across {len(fonts)} fonts")
    print(f"Errors: {errors}")
    print(f"Extracted crops:")
    for label, count in sorted(counts.items()):
        print(f" {label:10s}: {count:7d}")
    print(f"Output: {OUTPUT_ROOT}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
161
extract_fast.py
Normal file
161
extract_fast.py
Normal file
|
|
@ -0,0 +1,161 @@
|
||||||
|
"""Fast accidental extraction from PrIMuS using Verovio + cairosvg.
|
||||||
|
|
||||||
|
Optimizations:
|
||||||
|
- Pre-filter MEI files using agnostic encoding (skip files with no accidentals)
|
||||||
|
- Single font pass (Leipzig) - augmentation can be done at training time
|
||||||
|
- Flush output for progress monitoring
|
||||||
|
"""
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import cairosvg
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
sys.stdout.reconfigure(encoding="utf-8")
|
||||||
|
|
||||||
|
DATASET_ROOT = r"C:\src\accidentals\dataset"
|
||||||
|
OUTPUT_ROOT = r"C:\src\accidentals\crops"
|
||||||
|
|
||||||
|
GLYPH_MAP = {"E262": "sharp", "E260": "flat", "E261": "natural"}
|
||||||
|
|
||||||
|
ACCID_RE = re.compile(
|
||||||
|
r'class="(keyAccid|accid)"[^>]*>\s*'
|
||||||
|
r'<use xlink:href="#(E\d+)[^"]*"\s+'
|
||||||
|
r'transform="translate\((\d+),\s*(\d+)\)\s*scale\(([^,]+),\s*([^)]+)\)"'
|
||||||
|
)
|
||||||
|
SVG_DIM_RE = re.compile(r'width="(\d+)px" height="(\d+)px"')
|
||||||
|
|
||||||
|
|
||||||
|
def has_accidentals(agnostic_path: str) -> bool:
    """Quick check if an agnostic file contains any accidental tokens."""
    with open(agnostic_path, "r", encoding="utf-8") as handle:
        contents = handle.read()
    return "accidental." in contents
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Single-font extraction pass over all MEI files, saving crops per class.

    Faster variant of extract_accidentals.py: one font (Leipzig), one glob,
    and inline crop logic with flushed progress output.
    """
    # Imported lazily so the module can be imported without verovio installed.
    import verovio

    # Glob all MEI files at once (much faster than iterating dirs on Windows)
    print("Finding MEI files...", flush=True)
    mei_files = sorted(
        glob.glob(os.path.join(DATASET_ROOT, "package_*", "*", "*.mei"))
    )
    print(f"Found {len(mei_files)} MEI files (skipping pre-filter)", flush=True)

    # Create output directories
    for label in GLYPH_MAP.values():
        os.makedirs(os.path.join(OUTPUT_ROOT, label), exist_ok=True)

    # Setup Verovio
    tk = verovio.toolkit()
    tk.setOptions({
        "adjustPageHeight": True,
        "adjustPageWidth": True,
        "header": "none",
        "footer": "none",
        "breaks": "none",
        "scale": 100,
        "pageMarginLeft": 0,
        "pageMarginRight": 0,
        "pageMarginTop": 0,
        "pageMarginBottom": 0,
        "font": "Leipzig",
    })

    counts = {"sharp": 0, "flat": 0, "natural": 0}
    errors = 0
    t0 = time.time()

    for i, mei_path in enumerate(mei_files):
        if i % 1000 == 0:
            # Progress line with throughput and ETA estimate.
            elapsed = time.time() - t0
            rate = i / elapsed if elapsed > 0 else 0
            eta = (len(mei_files) - i) / rate if rate > 0 else 0
            print(
                f"[{i:6d}/{len(mei_files)}] "
                f"sharp={counts['sharp']} flat={counts['flat']} "
                f"natural={counts['natural']} "
                f"err={errors} "
                f"rate={rate:.0f}/s eta={eta/60:.0f}m",
                flush=True,
            )

        try:
            with open(mei_path, "r", encoding="utf-8") as f:
                mei_data = f.read()

            tk.loadData(mei_data)
            svg = tk.renderToSVG(1)
            if not svg:
                continue

            dim_match = SVG_DIM_RE.search(svg)
            if not dim_match:
                continue
            svg_w = int(dim_match.group(1))
            svg_h = int(dim_match.group(2))
            if svg_w == 0 or svg_h == 0:
                continue

            # Find accidentals
            accids = []
            for match in ACCID_RE.finditer(svg):
                glyph_type = GLYPH_MAP.get(match.group(2))
                if glyph_type:
                    # SVG viewBox coords are 10x the pixel dimensions.
                    accids.append({
                        "type": glyph_type,
                        "px": int(match.group(3)) / 10.0,
                        "py": int(match.group(4)) / 10.0,
                    })

            if not accids:
                continue

            # Rasterize
            png_bytes = cairosvg.svg2png(
                bytestring=svg.encode("utf-8"),
                output_width=svg_w,
                output_height=svg_h,
            )
            # Composite over white before grayscale conversion: the PNG
            # from cairosvg has a transparent background.
            img_rgba = Image.open(BytesIO(png_bytes))
            white_bg = Image.new("RGBA", img_rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(white_bg, img_rgba).convert("L")

            sample_id = os.path.basename(os.path.dirname(mei_path))
            for j, acc in enumerate(accids):
                # Fixed crop window around the glyph anchor, clamped to bounds.
                x1 = max(0, int(acc["px"] - 3))
                y1 = max(0, int(acc["py"] - 25))
                x2 = min(svg_w, int(acc["px"] + 22))
                y2 = min(svg_h, int(acc["py"] + 35))
                if x2 - x1 < 5 or y2 - y1 < 5:
                    continue

                crop = img.crop((x1, y1, x2, y2))
                fname = f"{sample_id}_{j}.png"
                crop.save(os.path.join(OUTPUT_ROOT, acc["type"], fname))
                counts[acc["type"]] += 1

        except Exception:
            # Best-effort pipeline: count the failure and move on.
            errors += 1
            if errors <= 5:
                traceback.print_exc()

    elapsed = time.time() - t0
    print(f"\n=== Done in {elapsed/60:.1f} minutes ===", flush=True)
    print(f"Errors: {errors}")
    print(f"Crops extracted:")
    for label, count in sorted(counts.items()):
        print(f" {label:10s}: {count:7d}")
    print(f"Total: {sum(counts.values())}")
    print(f"Output: {OUTPUT_ROOT}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
BIN
model/accidental_classifier.pt
Normal file
BIN
model/accidental_classifier.pt
Normal file
Binary file not shown.
122
segment_test.py
Normal file
122
segment_test.py
Normal file
|
|
@ -0,0 +1,122 @@
|
||||||
|
"""Test segmenting symbols from a PrIMuS image using connected components."""
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def segment_symbols(png_path: str):
    """Segment a PrIMuS image into individual symbols.

    Pipeline: binarize, remove staff lines via horizontal morphological
    opening, run connected-component analysis, then merge x-overlapping
    components into symbol clusters.

    Returns (clusters, original grayscale image, staff-free binary image).
    """
    # Load as grayscale
    img = cv2.imread(png_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # PrIMuS uses palette mode PNGs, load via PIL first
        pil_img = Image.open(png_path).convert("L")
        img = np.array(pil_img)

    h, w = img.shape
    print(f"Image size: {w} x {h}")

    # Binarize (black ink on white background)
    _, binary = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY_INV)

    # Remove staff lines using horizontal morphology
    # Detect staff lines: opening with a very wide 1px-tall kernel keeps
    # only long horizontal runs.
    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (w // 4, 1))
    staff_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horizontal_kernel)

    # Remove staff lines from binary image
    no_lines = binary - staff_lines

    # Clean up with small morphological close to reconnect broken symbols
    close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    no_lines = cv2.morphologyEx(no_lines, cv2.MORPH_CLOSE, close_kernel)

    # Find connected components
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        no_lines, connectivity=8
    )

    # Filter out tiny noise (area < 10 pixels)
    symbols = []
    for i in range(1, num_labels):  # skip background
        x, y, sw, sh, area = stats[i]
        if area < 10:
            continue
        symbols.append({
            "x": x, "y": y, "w": sw, "h": sh,
            "area": area, "cx": centroids[i][0],
        })

    # Sort by x position (left to right)
    symbols.sort(key=lambda s: s["x"])

    # Group nearby components into symbol clusters
    # (a single symbol like a sharp may have multiple disconnected parts)
    clusters = []
    for s in symbols:
        # Merge when this component starts within 5px of the current
        # cluster's right edge.
        if clusters and s["x"] < clusters[-1]["x2"] + 5:
            # Merge into previous cluster
            c = clusters[-1]
            c["x1"] = min(c["x1"], s["x"])
            c["y1"] = min(c["y1"], s["y"])
            c["x2"] = max(c["x2"], s["x"] + s["w"])
            c["y2"] = max(c["y2"], s["y"] + s["h"])
            c["components"].append(s)
        else:
            clusters.append({
                "x1": s["x"], "y1": s["y"],
                "x2": s["x"] + s["w"], "y2": s["y"] + s["h"],
                "components": [s],
            })

    print(f"Found {len(symbols)} components -> {len(clusters)} symbol clusters")
    for i, c in enumerate(clusters):
        w = c["x2"] - c["x1"]
        h = c["y2"] - c["y1"]
        print(f" Cluster {i:3d}: x={c['x1']:5d}-{c['x2']:5d} y={c['y1']:3d}-{c['y2']:3d} size={w:3d}x{h:3d}")

    return clusters, img, no_lines
|
||||||
|
|
||||||
|
|
||||||
|
def match_tokens_to_clusters(agnostic_path: str, clusters: list):
    """Match agnostic tokens to segmented clusters.

    Currently only prints the token list and the token/cluster counts for
    visual comparison; no actual alignment is performed yet.

    Args:
        agnostic_path: Path to a PrIMuS .agnostic file (tab-separated tokens).
        clusters: Symbol clusters as produced by segment_symbols().

    Returns:
        The list of agnostic tokens.
    """
    # Fix: read with explicit UTF-8 like every other agnostic-file reader in
    # this project; the platform default encoding (e.g. cp1252 on Windows)
    # could mis-decode the file.
    with open(agnostic_path, "r", encoding="utf-8") as f:
        tokens = f.read().strip().split("\t")

    print(f"\nAgnostic tokens ({len(tokens)}):")
    for i, tok in enumerate(tokens):
        print(f" {i:3d}: {tok}")

    print(f"\nClusters: {len(clusters)}")
    print(f"Tokens: {len(tokens)}")

    return tokens
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke-test segmentation on a single known sample.
    sample_dir = r"C:\src\accidentals\dataset\package_aa\000141058-1_1_1"
    sample_id = "000141058-1_1_1"

    png_path = f"{sample_dir}/{sample_id}.png"
    agnostic_path = f"{sample_dir}/{sample_id}.agnostic"

    clusters, img, no_lines = segment_symbols(png_path)
    tokens = match_tokens_to_clusters(agnostic_path, clusters)

    # Save debug images
    debug_dir = r"C:\src\accidentals\debug"
    import os
    os.makedirs(debug_dir, exist_ok=True)

    # Invert back to black-ink-on-white for easier viewing.
    cv2.imwrite(f"{debug_dir}/no_lines.png", 255 - no_lines)

    # Draw bounding boxes on original
    vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    for i, c in enumerate(clusters):
        color = (0, 0, 255)  # red in BGR
        cv2.rectangle(vis, (c["x1"], c["y1"]), (c["x2"], c["y2"]), color, 1)
        cv2.putText(vis, str(i), (c["x1"], c["y1"] - 2),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)
    cv2.imwrite(f"{debug_dir}/segmented.png", vis)
    print(f"\nDebug images saved to {debug_dir}/")
|
||||||
315
train.py
Normal file
315
train.py
Normal file
|
|
@ -0,0 +1,315 @@
|
||||||
|
"""Train a small CNN to classify accidentals (sharp/flat/natural).
|
||||||
|
|
||||||
|
Outputs an ONNX model for inference via OpenCV's dnn module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
sys.stdout.reconfigure(encoding="utf-8")
|
||||||
|
|
||||||
|
CROPS_ROOT = Path(r"C:\src\accidentals\crops")
|
||||||
|
OUTPUT_DIR = Path(r"C:\src\accidentals\model")
|
||||||
|
LABELS = ["flat", "natural", "sharp"] # alphabetical → class indices 0,1,2
|
||||||
|
IMG_SIZE = (40, 40) # square, small enough for fast inference
|
||||||
|
BATCH_SIZE = 256
|
||||||
|
NUM_EPOCHS = 30
|
||||||
|
LR = 1e-3
|
||||||
|
SEED = 42
|
||||||
|
VAL_SPLIT = 0.1
|
||||||
|
TEST_SPLIT = 0.05
|
||||||
|
MAX_PER_CLASS = 5000 # subsample: more than enough for this easy problem
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryDataset(Dataset):
    """Dataset serving pre-loaded image tensors directly from RAM.

    When augment=True, each sample passes through a light random affine
    transform followed by random erasing.
    """

    def __init__(self, images: torch.Tensor, labels: torch.Tensor, augment=False):
        self.images = images  # (N, 1, H, W) float32
        self.labels = labels  # (N,) long
        self.augment = augment
        # Transforms are constructed up front; they are only applied when
        # augment is True.
        self.affine = transforms.RandomAffine(
            degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)
        )
        self.erasing = transforms.RandomErasing(p=0.2, scale=(0.02, 0.08))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.images[idx]
        if self.augment:
            # Affine first, then erasing — same order as training expects.
            sample = self.erasing(self.affine(sample))
        return sample, self.labels[idx]
|
||||||
|
|
||||||
|
|
||||||
|
class AccidentalCNN(nn.Module):
    """Small CNN for accidental classification.

    Input: 1x40x40 grayscale image.
    Architecture: 3 conv blocks (conv -> batchnorm -> ReLU -> maxpool)
    followed by a 2-layer classifier head with dropout.

    Args:
        num_classes: Number of output classes. Defaults to len(LABELS)
            (3: flat/natural/sharp), so existing callers are unchanged;
            passing it explicitly generalizes the head to other label sets.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Default preserves the original behavior exactly.
            num_classes = len(LABELS)
        self.features = nn.Sequential(
            # Block 1: 1 -> 16 channels, 40x40 -> 20x20
            nn.Conv2d(1, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Block 2: 16 -> 32 channels, 20x20 -> 10x10
            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Block 3: 32 -> 64 channels, 10x10 -> 5x5
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(64 * 5 * 5, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return raw class logits of shape (N, num_classes)."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten conv features for the FC head
        x = self.classifier(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
def load_all_into_ram():
    """Load every crop into RAM as normalized tensors.

    Reads up to MAX_PER_CLASS random PNG crops per label from
    CROPS_ROOT/<label>, converts each to grayscale, resizes to IMG_SIZE
    and scales pixel values to [0, 1].

    Returns:
        (images, labels): ``images`` is a (N, 1, H, W) float32 tensor,
        ``labels`` a (N,) long tensor of indices into LABELS.
    """
    print("Loading all images into RAM...", flush=True)
    all_images = []
    all_labels = []

    for label_idx, label_name in enumerate(LABELS):
        label_dir = CROPS_ROOT / label_name
        files = [f for f in os.listdir(label_dir) if f.endswith(".png")]
        # Shuffle before truncating so MAX_PER_CLASS keeps a random subset.
        # NOTE: eval.py replays this exact shuffle order under the same seed
        # to reconstruct the identical split — do not reorder RNG calls.
        random.shuffle(files)
        if MAX_PER_CLASS and len(files) > MAX_PER_CLASS:
            files = files[:MAX_PER_CLASS]
        print(f" {label_name}: {len(files)} files", flush=True)

        for fname in files:
            # Fix: close each file handle promptly. The original left every
            # Image.open handle to be reclaimed by GC, which can exhaust OS
            # file descriptors when loading many thousands of crops.
            with Image.open(label_dir / fname) as img:
                gray = img.convert("L").resize(IMG_SIZE, Image.BILINEAR)
            arr = np.array(gray, dtype=np.float32) / 255.0
            all_images.append(arr)
            all_labels.append(label_idx)

    # Stack into tensors: (N, 1, H, W)
    images = torch.tensor(np.array(all_images)).unsqueeze(1)
    labels = torch.tensor(all_labels, dtype=torch.long)
    print(f" Total: {len(labels)} images, tensor shape: {images.shape}", flush=True)
    print(f" Memory: {images.nbytes / 1024**2:.0f} MB", flush=True)
    return images, labels
|
||||||
|
|
||||||
|
|
||||||
|
def split_data(images, labels):
    """Stratified train/val/test split. Returns index arrays."""
    # One global permutation, then bucketed per class so each split keeps
    # roughly the same class proportions.
    perm = torch.randperm(len(labels))

    buckets = {c: [] for c in range(len(LABELS))}
    for i in perm:
        buckets[labels[i].item()].append(i.item())

    train_idx, val_idx, test_idx = [], [], []
    for _cls, members in buckets.items():
        count = len(members)
        # At least one sample lands in test and val for every class.
        n_test = max(1, int(count * TEST_SPLIT))
        n_val = max(1, int(count * VAL_SPLIT))
        test_idx += members[:n_test]
        val_idx += members[n_test : n_test + n_val]
        train_idx += members[n_test + n_val :]

    return torch.tensor(train_idx), torch.tensor(val_idx), torch.tensor(test_idx)
|
||||||
|
|
||||||
|
|
||||||
|
def make_weighted_sampler(labels):
    """Create a sampler that oversamples minority classes."""
    # Each sample is weighted by the inverse frequency of its class, so
    # rare classes are drawn as often as common ones (with replacement).
    counts = torch.bincount(labels, minlength=len(LABELS)).float()
    per_sample_weights = counts[labels].reciprocal()
    return WeightedRandomSampler(per_sample_weights, len(labels), replacement=True)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(model, loader, device):
    """Evaluate model on a data loader.

    Returns (avg_loss, overall_accuracy, per_class_accuracy_array).
    """
    model.eval()
    correct = 0
    total = 0
    class_hits = np.zeros(len(LABELS))
    class_seen = np.zeros(len(LABELS))
    loss_sum = 0.0
    batches = 0

    with torch.no_grad():
        for batch_imgs, batch_lbls in loader:
            batch_imgs = batch_imgs.to(device)
            batch_lbls = batch_lbls.to(device)
            logits = model(batch_imgs)
            loss_sum += F.cross_entropy(logits, batch_lbls).item()
            batches += 1

            preds = logits.argmax(1)
            total += batch_lbls.size(0)
            correct += (preds == batch_lbls).sum().item()

            # Per-class tallies for the class-accuracy breakdown.
            for cls in range(len(LABELS)):
                sel = batch_lbls == cls
                class_seen[cls] += sel.sum().item()
                class_hits[cls] += (preds[sel] == cls).sum().item()

    overall_acc = correct / total if total > 0 else 0
    class_acc = np.where(class_seen > 0, class_hits / class_seen, 0)
    avg_loss = loss_sum / batches if batches > 0 else 0
    return avg_loss, overall_acc, class_acc
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Train the accidental classifier end-to-end and export it.

    Pipeline: seed RNGs -> load crops -> stratified split -> train with a
    class-balancing sampler -> keep best-val-acc checkpoint -> evaluate on
    the test split -> export ONNX and smoke-test it through OpenCV dnn.
    """
    # Seed all three RNGs up front. eval.py re-seeds identically and replays
    # the same load/split calls to reconstruct the exact same test set, so
    # the order of RNG-consuming calls below must not change.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # CPU-only training (model is tiny; no GPU assumed).
    device = torch.device("cpu")
    print(f"Device: {device}", flush=True)

    # Load everything into RAM
    images, labels = load_all_into_ram()

    # Split
    train_idx, val_idx, test_idx = split_data(images, labels)
    print(f"\nTrain: {len(train_idx)} Val: {len(val_idx)} Test: {len(test_idx)}", flush=True)
    # Per-class counts in each split (sanity check on stratification).
    for i, name in enumerate(LABELS):
        nt = (labels[train_idx] == i).sum().item()
        nv = (labels[val_idx] == i).sum().item()
        ne = (labels[test_idx] == i).sum().item()
        print(f" {name:8s}: train={nt:6d} val={nv:5d} test={ne:5d}", flush=True)

    # Create datasets (in-memory, fast)
    train_ds = InMemoryDataset(images[train_idx], labels[train_idx], augment=True)
    val_ds = InMemoryDataset(images[val_idx], labels[val_idx], augment=False)
    test_ds = InMemoryDataset(images[test_idx], labels[test_idx], augment=False)

    # Weighted sampler for class imbalance
    sampler = make_weighted_sampler(labels[train_idx])

    # num_workers=0: data is already tensors in RAM, worker processes would
    # only add overhead. sampler and shuffle are mutually exclusive.
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Model
    model = AccidentalCNN().to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {param_count:,}", flush=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # Halve the LR after 3 epochs without val-accuracy improvement
    # (mode="max" because we step on accuracy, not loss).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=3
    )

    # Training loop
    best_val_acc = 0.0
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    best_model_path = OUTPUT_DIR / "accidental_classifier.pt"

    # Table header; the three per-class columns follow LABELS order
    # (presumably flat/natural/sharp — confirm against LABELS).
    print(f"\n{'Epoch':>5s} {'TrLoss':>7s} {'VaLoss':>7s} {'VaAcc':>6s} "
          f"{'Flat':>5s} {'Nat':>5s} {'Sharp':>5s} {'LR':>8s} {'Time':>5s}", flush=True)
    print("-" * 70, flush=True)

    for epoch in range(1, NUM_EPOCHS + 1):
        t0 = time.time()
        model.train()
        train_loss = 0.0
        n_batches = 0

        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = F.cross_entropy(outputs, lbls)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_batches += 1

        avg_train_loss = train_loss / n_batches
        val_loss, val_acc, class_acc = evaluate(model, val_loader, device)
        scheduler.step(val_acc)

        elapsed = time.time() - t0
        lr = optimizer.param_groups[0]["lr"]

        print(
            f"{epoch:5d} {avg_train_loss:7.4f} {val_loss:7.4f} {val_acc:6.1%} "
            f"{class_acc[0]:5.1%} {class_acc[1]:5.1%} {class_acc[2]:5.1%} "
            f"{lr:.6f} {elapsed:5.1f}s",
            flush=True,
        )

        # Checkpoint only on val-accuracy improvement (best-model selection).
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)

    # Final evaluation on test set
    # Reload the best checkpoint — the in-memory model is from the LAST
    # epoch, not necessarily the best one.
    print(f"\n=== Test Set Evaluation ===", flush=True)
    model.load_state_dict(torch.load(best_model_path, weights_only=True))
    test_loss, test_acc, test_class_acc = evaluate(model, test_loader, device)
    print(f"Test accuracy: {test_acc:.1%}", flush=True)
    for i, name in enumerate(LABELS):
        print(f" {name:8s}: {test_class_acc[i]:.1%}", flush=True)

    # Export to ONNX for OpenCV dnn inference
    print(f"\nExporting to ONNX...", flush=True)
    model.eval()  # eval mode so dropout/batchnorm are frozen in the export
    dummy = torch.randn(1, 1, *IMG_SIZE)
    onnx_path = OUTPUT_DIR / "accidental_classifier.onnx"
    torch.onnx.export(
        model,
        dummy,
        str(onnx_path),
        input_names=["image"],
        output_names=["logits"],
        # Batch dimension left dynamic so callers can run batched inference.
        dynamic_axes={"image": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=13,
    )
    print(f"Saved ONNX model: {onnx_path}", flush=True)
    print(f"Saved PyTorch model: {best_model_path}", flush=True)

    # Quick ONNX verification via OpenCV
    # Local import: cv2 is only needed for this smoke test, and the
    # preprocessing (grayscale, 1/255 scale, no mean, no RB swap) mirrors
    # the training normalization.
    import cv2
    net = cv2.dnn.readNetFromONNX(str(onnx_path))
    test_img = np.random.randint(0, 255, (IMG_SIZE[1], IMG_SIZE[0]), dtype=np.uint8)
    blob = cv2.dnn.blobFromImage(test_img, 1.0 / 255.0, IMG_SIZE, 0, swapRB=False)
    net.setInput(blob)
    out = net.forward()
    pred = LABELS[np.argmax(out[0])]
    print(f"ONNX+OpenCV smoke test: input=random noise -> pred={pred} (expected: arbitrary)", flush=True)
    print(f"\nDone. Labels: {LABELS}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()
|
||||||
Loading…
Reference in a new issue