Compare commits
2 commits
a7164caf08
...
85c8cfd8bb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
85c8cfd8bb | ||
|
|
afc16c2bbb |
9 changed files with 1151 additions and 1 deletions
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
__pycache__/
|
||||
.claude/
|
||||
crops/
|
||||
dataset/
|
||||
debug/
|
||||
*.tgz
|
||||
nul
|
||||
83
README.md
83
README.md
|
|
@ -1,2 +1,83 @@
|
|||
# accidentals
|
||||
# Accidental Classifier
|
||||
|
||||
A small CNN that classifies musical accidentals — **sharp**, **flat**, and **natural** — from cropped grayscale images of engraved scores.
|
||||
|
||||
Trained on crops extracted from the [PrIMuS dataset](https://grfia.dlsi.ua.es/primus/) and achieves **100% accuracy** on a held-out test set (750 samples).
|
||||
|
||||
```
|
||||
Confusion matrix (rows=true, cols=predicted):
|
||||
flat natural sharp
|
||||
flat 250 0 0
|
||||
natural 0 250 0
|
||||
sharp 0 0 250
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Input** | 1 × 40 × 40 grayscale |
|
||||
| **Backbone** | 3 conv blocks (16 → 32 → 64 channels), each with BatchNorm + ReLU + MaxPool |
|
||||
| **Head** | Dropout → FC(1600, 64) → ReLU → Dropout → FC(64, 3) |
|
||||
| **Parameters** | ~115 k |
|
||||
| **Output** | 3-class logits (flat / natural / sharp) |
|
||||
|
||||
A saved ONNX export is also produced by `train.py` for inference via OpenCV's `dnn` module.
|
||||
|
||||
## Repository layout
|
||||
|
||||
```
|
||||
train.py Training pipeline (CNN + augmentation + weighted sampling)
|
||||
eval.py Evaluate saved checkpoint on the held-out test set
|
||||
extract_accidentals.py Extract accidental crops from PrIMuS MEI files via Verovio
|
||||
extract_fast.py Faster single-font variant of the extractor
|
||||
explore_dataset.py Explore PrIMuS agnostic encodings and image statistics
|
||||
segment_test.py Connected-component symbol segmentation experiment
|
||||
model/
|
||||
accidental_classifier.pt Saved PyTorch weights (best validation accuracy)
|
||||
```
|
||||
|
||||
## Pipeline
|
||||
|
||||
### 1. Extract crops
|
||||
|
||||
Render each PrIMuS MEI file with [Verovio](https://www.verovio.org/), locate accidental glyphs (SMuFL codepoints E260/E261/E262) in the SVG, rasterize with cairosvg, and crop each symbol into `crops/{flat,natural,sharp}/`.
|
||||
|
||||
```bash
|
||||
python extract_fast.py
|
||||
```
|
||||
|
||||
### 2. Train
|
||||
|
||||
Loads all crops into RAM, applies a stratified train/val/test split (seed 42), trains with data augmentation (random affine + erasing) and class-balanced sampling, and saves the best checkpoint.
|
||||
|
||||
```bash
|
||||
python train.py
|
||||
```
|
||||
|
||||
### 3. Evaluate
|
||||
|
||||
Reproduces the identical test split and reports accuracy, per-class metrics, the confusion matrix, and sample misclassifications.
|
||||
|
||||
```bash
|
||||
python eval.py
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
- Python 3.10+
|
||||
- PyTorch
|
||||
- torchvision
|
||||
- NumPy
|
||||
- Pillow
|
||||
- OpenCV (`cv2`) — for ONNX verification and segmentation experiments
|
||||
- Verovio (Python bindings) — for crop extraction only
|
||||
- cairosvg — for crop extraction only
|
||||
|
||||
## Data
|
||||
|
||||
Training data is derived from [PrIMuS](https://grfia.dlsi.ua.es/primus/) (Calvo-Zaragoza & Rizo, 2018). The extracted crops and raw dataset are not included in this repository due to size.
|
||||
|
||||
## License
|
||||
|
||||
This project is released into the public domain.
|
||||
|
|
|
|||
148
eval.py
Normal file
148
eval.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""Evaluate the saved accidental classifier on the held-out test set.
|
||||
|
||||
Uses the same seed (42), split ratios, and data loading as train.py
|
||||
so the test set is identical to what training used.
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from train import (
|
||||
AccidentalCNN,
|
||||
CROPS_ROOT,
|
||||
OUTPUT_DIR,
|
||||
LABELS,
|
||||
IMG_SIZE,
|
||||
SEED,
|
||||
VAL_SPLIT,
|
||||
TEST_SPLIT,
|
||||
MAX_PER_CLASS,
|
||||
InMemoryDataset,
|
||||
load_all_into_ram,
|
||||
split_data,
|
||||
)
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
MODEL_PATH = OUTPUT_DIR / "accidental_classifier.pt"
|
||||
|
||||
|
||||
def collect_file_paths():
    """Return crop file paths in the exact order load_all_into_ram visits them.

    Call only after seeding the RNG: the per-class random.shuffle below must
    consume the same random state as the identical loop in train.py, otherwise
    the path list will not line up with the loaded tensors.
    """
    collected = []
    for class_name in LABELS:
        class_dir = CROPS_ROOT / class_name
        pngs = [entry for entry in os.listdir(class_dir) if entry.endswith(".png")]
        # Shuffle first, then truncate — same order of RNG consumption as train.py.
        random.shuffle(pngs)
        if MAX_PER_CLASS and len(pngs) > MAX_PER_CLASS:
            del pngs[MAX_PER_CLASS:]
        collected.extend(str(class_dir / name) for name in pngs)
    return collected
|
||||
|
||||
|
||||
def main():
    """Evaluate the saved checkpoint on the reproduced held-out test split.

    Reconstructs the exact test set used during training by replaying the
    same seeds and split logic as train.py, then reports overall accuracy,
    per-class accuracy, a confusion matrix, and sample misclassifications.
    """
    # Reproduce the exact same RNG state and splits as train.py
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Collect file paths (consumes same random state as load_all_into_ram)
    file_paths = collect_file_paths()

    # Re-seed and reload to get tensors (load_all_into_ram shuffles again internally)
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    images, labels = load_all_into_ram()
    # Only the test indices are needed; train/val are discarded.
    _, _, test_idx = split_data(images, labels)

    print(f"\nTest set size: {len(test_idx)}")
    for i, name in enumerate(LABELS):
        n = (labels[test_idx] == i).sum().item()
        print(f" {name:8s}: {n}")

    # Load model (CPU is sufficient for this small CNN at inference time)
    device = torch.device("cpu")
    model = AccidentalCNN().to(device)
    model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))
    model.eval()
    print(f"\nLoaded weights from {MODEL_PATH}")

    # Run inference with augmentation disabled so results are deterministic
    test_ds = InMemoryDataset(images[test_idx], labels[test_idx], augment=False)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=0)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for imgs, lbls in test_loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            # argmax over class logits
            _, predicted = outputs.max(1)
            all_preds.append(predicted.cpu())
            all_labels.append(lbls)

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Overall accuracy
    correct = (all_preds == all_labels).sum().item()
    total = len(all_labels)
    print(f"\n=== Test Results ===")
    print(f"Overall accuracy: {correct}/{total} ({correct/total:.1%})")

    # Per-class accuracy
    print(f"\nPer-class accuracy:")
    for i, name in enumerate(LABELS):
        mask = all_labels == i
        cls_total = mask.sum().item()
        cls_correct = (all_preds[mask] == i).sum().item()
        acc = cls_correct / cls_total if cls_total > 0 else 0
        print(f" {name:8s}: {cls_correct}/{cls_total} ({acc:.1%})")

    # Confusion matrix
    n_classes = len(LABELS)
    cm = torch.zeros(n_classes, n_classes, dtype=torch.long)
    for t, p in zip(all_labels, all_preds):
        cm[t][p] += 1

    print(f"\nConfusion matrix (rows=true, cols=predicted):")
    header = " " + "".join(f"{name:>9s}" for name in LABELS)
    print(header)
    for i, name in enumerate(LABELS):
        row = f" {name:8s}" + "".join(f"{cm[i][j].item():9d}" for j in range(n_classes))
        print(row)

    # Sample misclassifications — map test positions back to original file
    # paths via the order reconstructed by collect_file_paths().
    test_indices = test_idx.tolist()
    misclassified = []
    for i in range(len(all_preds)):
        if all_preds[i] != all_labels[i]:
            orig_idx = test_indices[i]
            misclassified.append((
                file_paths[orig_idx],
                LABELS[all_labels[i].item()],
                LABELS[all_preds[i].item()],
            ))

    print(f"\nMisclassifications: {len(misclassified)}/{total}")
    if misclassified:
        print(f"\nSample misclassifications (up to 10):")
        for path, true_lbl, pred_lbl in misclassified[:10]:
            print(f" {path}")
            print(f" true={true_lbl} predicted={pred_lbl}")


if __name__ == "__main__":
    main()
|
||||
102
explore_dataset.py
Normal file
102
explore_dataset.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""Explore the PrIMuS dataset to understand accidental distribution and image structure."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from collections import Counter
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
DATASET_ROOT = r"C:\src\accidentals\dataset"
|
||||
|
||||
|
||||
def parse_agnostic(path: str) -> list[str]:
    """Read an agnostic encoding file and return its tab-separated tokens."""
    with open(path, "r", encoding="utf-8") as handle:
        raw = handle.read()
    return raw.strip().split("\t")
|
||||
|
||||
|
||||
def main():
    """Scan PrIMuS agnostic encodings and print accidental/image statistics.

    Reports how many incipits contain accidentals (overall and inline, i.e.
    past the time signature), the frequency of each accidental token, the
    most common symbol types, and size statistics for a sample of images.
    """
    # Find all agnostic files
    patterns = [
        os.path.join(DATASET_ROOT, "package_aa", "*", "*.agnostic"),
        os.path.join(DATASET_ROOT, "package_ab", "*", "*.agnostic"),
    ]
    agnostic_files = []
    for p in patterns:
        agnostic_files.extend(glob.glob(p))

    print(f"Total incipits: {len(agnostic_files)}")

    # Count accidental tokens
    accidental_type_counts = Counter()  # sharp/flat/natural
    accidental_full_counts = Counter()  # full token like accidental.sharp-L5
    incipits_with_accidentals = 0
    incipits_with_inline_accidentals = 0  # accidentals that aren't in key sig
    all_symbol_types = Counter()
    total_accidentals = 0

    for path in agnostic_files:
        tokens = parse_agnostic(path)
        has_any_accidental = False
        has_inline = False
        # Heuristic: a digit token marks the time signature; accidentals seen
        # after it are treated as inline (not part of the key signature).
        past_time_sig = False

        for tok in tokens:
            # Track symbol types (just the prefix)
            base = tok.split("-")[0] if "-" in tok else tok
            all_symbol_types[base] += 1

            if tok.startswith("digit."):
                past_time_sig = True

            if tok.startswith("accidental."):
                has_any_accidental = True
                total_accidentals += 1
                # Extract type: accidental.sharp, accidental.flat, etc.
                acc_type = tok.split("-")[0]  # e.g. "accidental.sharp"
                accidental_type_counts[acc_type] += 1
                accidental_full_counts[tok] += 1

                if past_time_sig:
                    has_inline = True

        if has_any_accidental:
            incipits_with_accidentals += 1
        if has_inline:
            incipits_with_inline_accidentals += 1

    print(f"\n=== Accidental Statistics ===")
    print(f"Total accidental tokens: {total_accidentals}")
    print(f"Incipits with any accidentals: {incipits_with_accidentals} / {len(agnostic_files)} ({100*incipits_with_accidentals/len(agnostic_files):.1f}%)")
    print(f"Incipits with inline accidentals: {incipits_with_inline_accidentals} / {len(agnostic_files)} ({100*incipits_with_inline_accidentals/len(agnostic_files):.1f}%)")

    print(f"\n=== Accidental Type Counts ===")
    for acc_type, count in accidental_type_counts.most_common():
        print(f" {acc_type:25s} {count:7d}")

    print(f"\n=== Top 20 Accidental Positions ===")
    for tok, count in accidental_full_counts.most_common(20):
        print(f" {tok:30s} {count:7d}")

    print(f"\n=== Top 30 Symbol Types ===")
    for sym, count in all_symbol_types.most_common(30):
        print(f" {sym:30s} {count:7d}")

    # Image statistics from a sample
    print(f"\n=== Image Statistics (sample of 500) ===")
    png_files = glob.glob(os.path.join(DATASET_ROOT, "package_aa", "*", "*.png"))[:500]
    widths, heights = [], []
    for f in png_files:
        im = Image.open(f)
        widths.append(im.size[0])
        heights.append(im.size[1])

    widths = np.array(widths)
    heights = np.array(heights)
    print(f" Width: min={widths.min()}, max={widths.max()}, mean={widths.mean():.0f}, std={widths.std():.0f}")
    print(f" Height: min={heights.min()}, max={heights.max()}, mean={heights.mean():.0f}, std={heights.std():.0f}")
    # Only the first 50 files are re-opened to sample PIL modes
    print(f" Modes: {Counter(Image.open(f).mode for f in png_files[:50])}")


if __name__ == "__main__":
    main()
|
||||
214
extract_accidentals.py
Normal file
214
extract_accidentals.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
"""Extract accidental crops from the PrIMuS dataset using Verovio re-rendering.
|
||||
|
||||
For each MEI file:
|
||||
1. Render with Verovio to SVG (getting exact symbol positions)
|
||||
2. Rasterize SVG to PNG with cairosvg
|
||||
3. Parse SVG for accidental elements (sharp/flat/natural)
|
||||
4. Crop each accidental from the rasterized image
|
||||
5. Save organized by type: crops/{sharp,flat,natural}/
|
||||
|
||||
SMuFL glyph IDs:
|
||||
E262 = sharp
|
||||
E260 = flat
|
||||
E261 = natural
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from io import BytesIO
|
||||
|
||||
import cairosvg
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
|
||||
DATASET_ROOT = r"C:\src\accidentals\dataset"
|
||||
OUTPUT_ROOT = r"C:\src\accidentals\crops"
|
||||
|
||||
GLYPH_MAP = {"E262": "sharp", "E260": "flat", "E261": "natural"}
|
||||
|
||||
# Regex to find accidental <use> elements in Verovio SVG
|
||||
ACCID_RE = re.compile(
|
||||
r'class="(keyAccid|accid)"[^>]*>\s*'
|
||||
r'<use xlink:href="#(E\d+)[^"]*"\s+'
|
||||
r'transform="translate\((\d+),\s*(\d+)\)\s*scale\(([^,]+),\s*([^)]+)\)"'
|
||||
)
|
||||
SVG_DIM_RE = re.compile(r'width="(\d+)px" height="(\d+)px"')
|
||||
|
||||
# Crop parameters (in SVG pixel units, viewBox/10)
|
||||
CROP_PAD_LEFT = 3
|
||||
CROP_PAD_RIGHT = 22
|
||||
CROP_PAD_TOP = 25
|
||||
CROP_PAD_BOTTOM = 35
|
||||
|
||||
|
||||
def setup_verovio():
    """Create a Verovio toolkit configured for tight single-system rendering.

    Headers, footers, margins, and system breaks are all disabled so the SVG
    contains only the music, which keeps crop coordinates simple.
    """
    import verovio

    options = {
        "adjustPageHeight": True,
        "adjustPageWidth": True,
        "header": "none",
        "footer": "none",
        "breaks": "none",
        "scale": 100,
        "pageMarginLeft": 0,
        "pageMarginRight": 0,
        "pageMarginTop": 0,
        "pageMarginBottom": 0,
    }
    toolkit = verovio.toolkit()
    toolkit.setOptions(options)
    return toolkit
|
||||
|
||||
|
||||
def extract_from_mei(tk, mei_path: str) -> list[dict]:
    """Extract accidental crops from a single MEI file.

    Renders the MEI with Verovio to SVG, scans the SVG text for accidental
    glyph <use> elements, rasterizes the SVG with cairosvg, and crops a
    fixed-size window around each accidental.

    Returns list of dicts with 'type', 'image' (PIL Image), 'context' (keyAccid/accid).
    Returns [] on any unusable input (empty SVG, missing dimensions,
    zero-sized page, no accidentals, or rasterization failure).
    """
    with open(mei_path, "r", encoding="utf-8") as f:
        mei_data = f.read()

    tk.loadData(mei_data)
    # Render page 1 only; the toolkit is configured with breaks=none so the
    # whole incipit lands on one page.
    svg = tk.renderToSVG(1)

    if not svg:
        return []

    # Get SVG pixel dimensions
    dim_match = SVG_DIM_RE.search(svg)
    if not dim_match:
        return []
    svg_w = int(dim_match.group(1))
    svg_h = int(dim_match.group(2))

    if svg_w == 0 or svg_h == 0:
        return []

    # Find accidentals in SVG via regex over the raw SVG text
    accid_positions = []
    for match in ACCID_RE.finditer(svg):
        cls = match.group(1)        # "keyAccid" or "accid"
        glyph_id = match.group(2)   # SMuFL codepoint, e.g. "E262"
        tx, ty = int(match.group(3)), int(match.group(4))

        glyph_type = GLYPH_MAP.get(glyph_id)
        if glyph_type is None:
            # Not a sharp/flat/natural glyph — skip
            continue

        # viewBox coords -> pixel coords (viewBox is 10x pixel dimensions)
        px = tx / 10.0
        py = ty / 10.0

        accid_positions.append(
            {"type": glyph_type, "context": cls, "px": px, "py": py}
        )

    if not accid_positions:
        return []

    # Rasterize SVG to PNG; treat any cairosvg failure as "no crops"
    try:
        png_bytes = cairosvg.svg2png(
            bytestring=svg.encode("utf-8"),
            output_width=svg_w,
            output_height=svg_h,
        )
    except Exception:
        return []

    img = Image.open(BytesIO(png_bytes)).convert("L")

    # Crop each accidental using the fixed CROP_PAD_* window, clamped to the
    # image bounds
    results = []
    for acc in accid_positions:
        x1 = max(0, int(acc["px"] - CROP_PAD_LEFT))
        y1 = max(0, int(acc["py"] - CROP_PAD_TOP))
        x2 = min(svg_w, int(acc["px"] + CROP_PAD_RIGHT))
        y2 = min(svg_h, int(acc["py"] + CROP_PAD_BOTTOM))

        # Skip degenerate crops near the image edges
        if x2 - x1 < 5 or y2 - y1 < 5:
            continue

        crop = img.crop((x1, y1, x2, y2))
        results.append(
            {
                "type": acc["type"],
                "context": acc["context"],
                "image": crop,
            }
        )

    return results
|
||||
|
||||
|
||||
def main():
    """Extract accidental crops from every MEI file, once per Verovio font.

    Each of the three SMuFL fonts produces a separate pass over the whole
    corpus so the saved crops vary in engraving style. Crops are written to
    OUTPUT_ROOT/{sharp,flat,natural}/ with the font index in the filename.
    """
    # Find all MEI files
    mei_files = []
    for pkg in ["package_aa", "package_ab"]:
        mei_files.extend(
            glob.glob(os.path.join(DATASET_ROOT, pkg, "*", "*.mei"))
        )
    # Sort for a deterministic processing order
    mei_files.sort()

    print(f"Found {len(mei_files)} MEI files")

    # Create output directories
    for label in GLYPH_MAP.values():
        os.makedirs(os.path.join(OUTPUT_ROOT, label), exist_ok=True)

    tk = setup_verovio()

    # Process with multiple Verovio fonts for variety
    fonts = ["Leipzig", "Bravura", "Gootville"]
    counts = {"sharp": 0, "flat": 0, "natural": 0}
    errors = 0
    total_processed = 0

    for font_idx, font in enumerate(fonts):
        # Reuse the same toolkit, only switching the music font per pass
        tk.setOptions({"font": font})
        print(f"\n=== Font: {font} (pass {font_idx+1}/{len(fonts)}) ===")

        for i, mei_path in enumerate(mei_files):
            # Progress heartbeat every 5000 files
            if i % 5000 == 0:
                print(
                    f" [{font}] {i}/{len(mei_files)} "
                    f"sharp={counts['sharp']} flat={counts['flat']} "
                    f"natural={counts['natural']} errors={errors}"
                )

            try:
                results = extract_from_mei(tk, mei_path)
            except Exception:
                # Best-effort extraction: count the error, show only the
                # first few tracebacks, and keep going
                errors += 1
                if errors <= 5:
                    traceback.print_exc()
                continue

            # The parent directory name is the PrIMuS sample id
            sample_id = os.path.basename(os.path.dirname(mei_path))
            for j, res in enumerate(results):
                label = res["type"]
                fname = f"{sample_id}_f{font_idx}_{j}.png"
                out_path = os.path.join(OUTPUT_ROOT, label, fname)
                res["image"].save(out_path)
                counts[label] += 1

            total_processed += 1

    print(f"\n=== Done ===")
    print(f"Processed: {total_processed} MEI files across {len(fonts)} fonts")
    print(f"Errors: {errors}")
    print(f"Extracted crops:")
    for label, count in sorted(counts.items()):
        print(f" {label:10s}: {count:7d}")
    print(f"Output: {OUTPUT_ROOT}")


if __name__ == "__main__":
    main()
|
||||
161
extract_fast.py
Normal file
161
extract_fast.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
"""Fast accidental extraction from PrIMuS using Verovio + cairosvg.
|
||||
|
||||
Optimizations:
|
||||
- Pre-filter MEI files using agnostic encoding (skip files with no accidentals)
|
||||
- Single font pass (Leipzig) - augmentation can be done at training time
|
||||
- Flush output for progress monitoring
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from io import BytesIO
|
||||
|
||||
import cairosvg
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
|
||||
DATASET_ROOT = r"C:\src\accidentals\dataset"
|
||||
OUTPUT_ROOT = r"C:\src\accidentals\crops"
|
||||
|
||||
GLYPH_MAP = {"E262": "sharp", "E260": "flat", "E261": "natural"}
|
||||
|
||||
ACCID_RE = re.compile(
|
||||
r'class="(keyAccid|accid)"[^>]*>\s*'
|
||||
r'<use xlink:href="#(E\d+)[^"]*"\s+'
|
||||
r'transform="translate\((\d+),\s*(\d+)\)\s*scale\(([^,]+),\s*([^)]+)\)"'
|
||||
)
|
||||
SVG_DIM_RE = re.compile(r'width="(\d+)px" height="(\d+)px"')
|
||||
|
||||
|
||||
def has_accidentals(agnostic_path: str) -> bool:
    """Return True if the agnostic encoding file mentions any accidental token."""
    with open(agnostic_path, "r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents.find("accidental.") != -1
|
||||
|
||||
|
||||
def main():
    """Single-font, single-pass crop extraction over the whole PrIMuS corpus.

    Inlines the render/parse/rasterize/crop pipeline (cf. extract_accidentals.py)
    inside one try/except so any per-file failure is counted and skipped.
    Crops are saved to OUTPUT_ROOT/{sharp,flat,natural}/.
    """
    import verovio

    # Glob all MEI files at once (much faster than iterating dirs on Windows)
    print("Finding MEI files...", flush=True)
    mei_files = sorted(
        glob.glob(os.path.join(DATASET_ROOT, "package_*", "*", "*.mei"))
    )
    print(f"Found {len(mei_files)} MEI files (skipping pre-filter)", flush=True)

    # Create output directories
    for label in GLYPH_MAP.values():
        os.makedirs(os.path.join(OUTPUT_ROOT, label), exist_ok=True)

    # Setup Verovio: no headers/footers/margins, single system, Leipzig font
    tk = verovio.toolkit()
    tk.setOptions({
        "adjustPageHeight": True,
        "adjustPageWidth": True,
        "header": "none",
        "footer": "none",
        "breaks": "none",
        "scale": 100,
        "pageMarginLeft": 0,
        "pageMarginRight": 0,
        "pageMarginTop": 0,
        "pageMarginBottom": 0,
        "font": "Leipzig",
    })

    counts = {"sharp": 0, "flat": 0, "natural": 0}
    errors = 0
    t0 = time.time()

    for i, mei_path in enumerate(mei_files):
        # Progress line with throughput and ETA every 1000 files
        if i % 1000 == 0:
            elapsed = time.time() - t0
            rate = i / elapsed if elapsed > 0 else 0
            eta = (len(mei_files) - i) / rate if rate > 0 else 0
            print(
                f"[{i:6d}/{len(mei_files)}] "
                f"sharp={counts['sharp']} flat={counts['flat']} "
                f"natural={counts['natural']} "
                f"err={errors} "
                f"rate={rate:.0f}/s eta={eta/60:.0f}m",
                flush=True,
            )

        try:
            with open(mei_path, "r", encoding="utf-8") as f:
                mei_data = f.read()

            tk.loadData(mei_data)
            svg = tk.renderToSVG(1)
            if not svg:
                continue

            dim_match = SVG_DIM_RE.search(svg)
            if not dim_match:
                continue
            svg_w = int(dim_match.group(1))
            svg_h = int(dim_match.group(2))
            if svg_w == 0 or svg_h == 0:
                continue

            # Find accidentals (viewBox coords are 10x pixel coords)
            accids = []
            for match in ACCID_RE.finditer(svg):
                glyph_type = GLYPH_MAP.get(match.group(2))
                if glyph_type:
                    accids.append({
                        "type": glyph_type,
                        "px": int(match.group(3)) / 10.0,
                        "py": int(match.group(4)) / 10.0,
                    })

            if not accids:
                continue

            # Rasterize, then flatten transparency onto a white background
            # before converting to grayscale
            png_bytes = cairosvg.svg2png(
                bytestring=svg.encode("utf-8"),
                output_width=svg_w,
                output_height=svg_h,
            )
            img_rgba = Image.open(BytesIO(png_bytes))
            white_bg = Image.new("RGBA", img_rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(white_bg, img_rgba).convert("L")

            # The parent directory name is the PrIMuS sample id
            sample_id = os.path.basename(os.path.dirname(mei_path))
            for j, acc in enumerate(accids):
                # Fixed crop window around the glyph anchor, clamped to bounds
                x1 = max(0, int(acc["px"] - 3))
                y1 = max(0, int(acc["py"] - 25))
                x2 = min(svg_w, int(acc["px"] + 22))
                y2 = min(svg_h, int(acc["py"] + 35))
                if x2 - x1 < 5 or y2 - y1 < 5:
                    continue

                crop = img.crop((x1, y1, x2, y2))
                fname = f"{sample_id}_{j}.png"
                crop.save(os.path.join(OUTPUT_ROOT, acc["type"], fname))
                counts[acc["type"]] += 1

        except Exception:
            # Best-effort: count and continue; show only the first tracebacks
            errors += 1
            if errors <= 5:
                traceback.print_exc()

    elapsed = time.time() - t0
    print(f"\n=== Done in {elapsed/60:.1f} minutes ===", flush=True)
    print(f"Errors: {errors}")
    print(f"Crops extracted:")
    for label, count in sorted(counts.items()):
        print(f" {label:10s}: {count:7d}")
    print(f"Total: {sum(counts.values())}")
    print(f"Output: {OUTPUT_ROOT}")


if __name__ == "__main__":
    main()
|
||||
BIN
model/accidental_classifier.pt
Normal file
BIN
model/accidental_classifier.pt
Normal file
Binary file not shown.
122
segment_test.py
Normal file
122
segment_test.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Test segmenting symbols from a PrIMuS image using connected components."""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def segment_symbols(png_path: str):
    """Segment a PrIMuS image into individual symbols.

    Binarizes the score image, removes staff lines with horizontal
    morphology, finds connected components, and merges horizontally
    overlapping components into symbol clusters.

    Returns (clusters, img, no_lines): the cluster dicts, the grayscale
    source image, and the staff-line-removed binary image.
    """
    # Load as grayscale
    img = cv2.imread(png_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # PrIMuS uses palette mode PNGs, load via PIL first
        pil_img = Image.open(png_path).convert("L")
        img = np.array(pil_img)

    h, w = img.shape
    print(f"Image size: {w} x {h}")

    # Binarize (black ink on white background)
    _, binary = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY_INV)

    # Remove staff lines using horizontal morphology
    # Detect staff lines: an opening with a very wide flat kernel keeps only
    # long horizontal runs
    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (w // 4, 1))
    staff_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horizontal_kernel)

    # Remove staff lines from binary image
    no_lines = binary - staff_lines

    # Clean up with small morphological close to reconnect broken symbols
    close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    no_lines = cv2.morphologyEx(no_lines, cv2.MORPH_CLOSE, close_kernel)

    # Find connected components
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        no_lines, connectivity=8
    )

    # Filter out tiny noise (area < 10 pixels)
    symbols = []
    for i in range(1, num_labels):  # skip background
        x, y, sw, sh, area = stats[i]
        if area < 10:
            continue
        symbols.append({
            "x": x, "y": y, "w": sw, "h": sh,
            "area": area, "cx": centroids[i][0],
        })

    # Sort by x position (left to right)
    symbols.sort(key=lambda s: s["x"])

    # Group nearby components into symbol clusters
    # (a single symbol like a sharp may have multiple disconnected parts)
    clusters = []
    for s in symbols:
        # Merge when this component starts within 5px of the previous
        # cluster's right edge
        if clusters and s["x"] < clusters[-1]["x2"] + 5:
            # Merge into previous cluster
            c = clusters[-1]
            c["x1"] = min(c["x1"], s["x"])
            c["y1"] = min(c["y1"], s["y"])
            c["x2"] = max(c["x2"], s["x"] + s["w"])
            c["y2"] = max(c["y2"], s["y"] + s["h"])
            c["components"].append(s)
        else:
            clusters.append({
                "x1": s["x"], "y1": s["y"],
                "x2": s["x"] + s["w"], "y2": s["y"] + s["h"],
                "components": [s],
            })

    print(f"Found {len(symbols)} components -> {len(clusters)} symbol clusters")
    for i, c in enumerate(clusters):
        w = c["x2"] - c["x1"]
        h = c["y2"] - c["y1"]
        print(f" Cluster {i:3d}: x={c['x1']:5d}-{c['x2']:5d} y={c['y1']:3d}-{c['y2']:3d} size={w:3d}x{h:3d}")

    return clusters, img, no_lines
|
||||
|
||||
|
||||
def match_tokens_to_clusters(agnostic_path: str, clusters: list):
    """Match agnostic tokens to segmented clusters.

    Currently only prints the token list and the token/cluster counts for
    manual comparison; no actual alignment is performed yet.

    Args:
        agnostic_path: Path to a PrIMuS .agnostic encoding file
            (tab-separated tokens).
        clusters: Cluster dicts from segment_symbols(); only the count
            is used.

    Returns:
        The list of tokens read from the file.
    """
    # Fix: read as UTF-8 explicitly, matching every other file read in this
    # codebase (relying on the platform default encoding is fragile on Windows).
    with open(agnostic_path, "r", encoding="utf-8") as f:
        tokens = f.read().strip().split("\t")

    print(f"\nAgnostic tokens ({len(tokens)}):")
    for i, tok in enumerate(tokens):
        print(f" {i:3d}: {tok}")

    print(f"\nClusters: {len(clusters)}")
    print(f"Tokens: {len(tokens)}")

    return tokens
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Demo run on a single hard-coded PrIMuS sample: segment it, list the
    # agnostic tokens, and write annotated debug images.
    sample_dir = r"C:\src\accidentals\dataset\package_aa\000141058-1_1_1"
    sample_id = "000141058-1_1_1"

    png_path = f"{sample_dir}/{sample_id}.png"
    agnostic_path = f"{sample_dir}/{sample_id}.agnostic"

    clusters, img, no_lines = segment_symbols(png_path)
    tokens = match_tokens_to_clusters(agnostic_path, clusters)

    # Save debug images
    debug_dir = r"C:\src\accidentals\debug"
    import os
    os.makedirs(debug_dir, exist_ok=True)

    # Invert back to black-on-white before saving the line-removed image
    cv2.imwrite(f"{debug_dir}/no_lines.png", 255 - no_lines)

    # Draw bounding boxes on original (red, with the cluster index above
    # each box)
    vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    for i, c in enumerate(clusters):
        color = (0, 0, 255)
        cv2.rectangle(vis, (c["x1"], c["y1"]), (c["x2"], c["y2"]), color, 1)
        cv2.putText(vis, str(i), (c["x1"], c["y1"] - 2),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)
    cv2.imwrite(f"{debug_dir}/segmented.png", vis)
    print(f"\nDebug images saved to {debug_dir}/")
|
||||
315
train.py
Normal file
315
train.py
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
"""Train a small CNN to classify accidentals (sharp/flat/natural).
|
||||
|
||||
Outputs an ONNX model for inference via OpenCV's dnn module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
|
||||
CROPS_ROOT = Path(r"C:\src\accidentals\crops")
|
||||
OUTPUT_DIR = Path(r"C:\src\accidentals\model")
|
||||
LABELS = ["flat", "natural", "sharp"] # alphabetical → class indices 0,1,2
|
||||
IMG_SIZE = (40, 40) # square, small enough for fast inference
|
||||
BATCH_SIZE = 256
|
||||
NUM_EPOCHS = 30
|
||||
LR = 1e-3
|
||||
SEED = 42
|
||||
VAL_SPLIT = 0.1
|
||||
TEST_SPLIT = 0.05
|
||||
MAX_PER_CLASS = 5000 # subsample: more than enough for this easy problem
|
||||
|
||||
|
||||
class InMemoryDataset(Dataset):
    """Dataset that holds all images as tensors in RAM.

    Optionally applies per-sample augmentation (random affine + random
    erasing) at access time; eval code constructs it with augment=False so
    samples come back unchanged.
    """

    def __init__(self, images: torch.Tensor, labels: torch.Tensor, augment: bool = False):
        self.images = images  # (N, 1, H, W) float32
        self.labels = labels  # (N,) long
        self.augment = augment  # apply transforms in __getitem__ when True
        # Augmentation transforms are built unconditionally but only used
        # when augment is True
        self.affine = transforms.RandomAffine(
            degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)
        )
        self.erasing = transforms.RandomErasing(p=0.2, scale=(0.02, 0.08))

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx):
        # Returns (image, label); image is freshly augmented per access when
        # augment is enabled (affine first, then erasing)
        img = self.images[idx]
        if self.augment:
            img = self.affine(img)
            img = self.erasing(img)
        return img, self.labels[idx]
|
||||
|
||||
|
||||
class AccidentalCNN(nn.Module):
    """Small CNN for 3-class accidental classification.

    Input: 1x40x40 grayscale image
    Architecture: 3 conv blocks + 1 FC layer
    """

    def __init__(self):
        super().__init__()
        # Three conv blocks (1->16->32->64 channels); each MaxPool halves the
        # spatial size: 40x40 -> 20x20 -> 10x10 -> 5x5. Building the list in
        # a loop keeps the module order (and state_dict keys) identical to a
        # hand-written Sequential.
        layers = []
        in_ch = 1
        for out_ch in (16, 32, 64):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ])
            in_ch = out_ch
        self.features = nn.Sequential(*layers)

        # Classifier head: dropout-regularized two-layer MLP over the
        # flattened 64x5x5 feature map
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(64 * 5 * 5, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(64, len(LABELS)),
        )

    def forward(self, x):
        feats = self.features(x)
        flat = feats.flatten(1)
        return self.classifier(flat)
|
||||
|
||||
|
||||
def load_all_into_ram():
    """Read every crop from disk into one (N, 1, H, W) float32 tensor.

    Returns (images, labels): pixel values scaled to [0, 1] and labels
    as class indices into LABELS.
    """
    print("Loading all images into RAM...", flush=True)
    arrays = []
    label_ids = []

    for class_id, class_name in enumerate(LABELS):
        class_dir = CROPS_ROOT / class_name
        filenames = [f for f in os.listdir(class_dir) if f.endswith(".png")]
        random.shuffle(filenames)
        # Cap the per-class count; a subsample is plenty for this task.
        if MAX_PER_CLASS and len(filenames) > MAX_PER_CLASS:
            filenames = filenames[:MAX_PER_CLASS]
        print(f" {class_name}: {len(filenames)} files", flush=True)

        for filename in filenames:
            gray = Image.open(class_dir / filename).convert("L")
            gray = gray.resize(IMG_SIZE, Image.BILINEAR)
            arrays.append(np.asarray(gray, dtype=np.float32) / 255.0)
            label_ids.append(class_id)

    # Stack into tensors: (N, 1, H, W)
    images = torch.from_numpy(np.stack(arrays)).unsqueeze(1)
    labels = torch.tensor(label_ids, dtype=torch.long)
    print(f" Total: {len(labels)} images, tensor shape: {images.shape}", flush=True)
    print(f" Memory: {images.nbytes / 1024**2:.0f} MB", flush=True)
    return images, labels
|
||||
|
||||
|
||||
def split_data(images, labels):
    """Stratified train/val/test split. Returns three index tensors."""
    # Shuffle once, then bucket the shuffled indices by class so every
    # split keeps the overall class proportions.
    order = torch.randperm(len(labels))

    per_class = {c: [] for c in range(len(LABELS))}
    for i in order:
        per_class[labels[i].item()].append(i.item())

    train_idx, val_idx, test_idx = [], [], []
    for idxs in per_class.values():
        count = len(idxs)
        # At least one sample per class in val and test, even for tiny classes.
        n_test = max(1, int(count * TEST_SPLIT))
        n_val = max(1, int(count * VAL_SPLIT))
        test_idx += idxs[:n_test]
        val_idx += idxs[n_test : n_test + n_val]
        train_idx += idxs[n_test + n_val :]

    return torch.tensor(train_idx), torch.tensor(val_idx), torch.tensor(test_idx)
|
||||
|
||||
|
||||
def make_weighted_sampler(labels):
    """Build a sampler that rebalances class frequencies.

    Each sample is drawn with probability inversely proportional to its
    class size, so minority classes are oversampled.
    """
    counts = torch.bincount(labels, minlength=len(LABELS)).float()
    per_sample_weight = counts[labels].reciprocal()
    return WeightedRandomSampler(per_sample_weight, len(labels), replacement=True)
|
||||
|
||||
|
||||
def evaluate(model, loader, device):
    """Run `model` over `loader` and collect loss/accuracy metrics.

    Returns (avg_loss, overall_accuracy, per_class_accuracy) where the
    last is an ndarray indexed like LABELS.
    """
    model.eval()
    n_classes = len(LABELS)
    hits_per_class = np.zeros(n_classes)
    seen_per_class = np.zeros(n_classes)
    hits = 0
    seen = 0
    batches = 0
    loss_sum = 0.0

    with torch.no_grad():
        for batch_imgs, batch_lbls in loader:
            batch_imgs = batch_imgs.to(device)
            batch_lbls = batch_lbls.to(device)
            logits = model(batch_imgs)
            loss_sum += F.cross_entropy(logits, batch_lbls).item()
            batches += 1

            preds = logits.max(1)[1]
            seen += batch_lbls.size(0)
            hits += (preds == batch_lbls).sum().item()

            # Per-class tallies for the per-label accuracy breakdown.
            for c in range(n_classes):
                of_class = batch_lbls == c
                seen_per_class[c] += of_class.sum().item()
                hits_per_class[c] += (preds[of_class] == c).sum().item()

    overall = hits / seen if seen > 0 else 0
    per_class = np.where(seen_per_class > 0, hits_per_class / seen_per_class, 0)
    mean_loss = loss_sum / batches if batches > 0 else 0
    return mean_loss, overall, per_class
|
||||
|
||||
|
||||
def main():
    """Train the accidental classifier end-to-end and export it.

    Pipeline: seed RNGs -> load crops into RAM -> stratified split ->
    build loaders (weighted sampling for class balance) -> train with
    LR-on-plateau scheduling, checkpointing the best validation accuracy
    -> evaluate that checkpoint on the held-out test set -> export to
    ONNX and smoke-test the export through OpenCV's dnn module.
    """
    # Seed every RNG source used below so runs are reproducible.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # CPU is deliberate: the model is tiny and the data fits in RAM.
    device = torch.device("cpu")
    print(f"Device: {device}", flush=True)

    # Load everything into RAM
    images, labels = load_all_into_ram()

    # Split
    train_idx, val_idx, test_idx = split_data(images, labels)
    print(f"\nTrain: {len(train_idx)} Val: {len(val_idx)} Test: {len(test_idx)}", flush=True)
    for i, name in enumerate(LABELS):
        nt = (labels[train_idx] == i).sum().item()
        nv = (labels[val_idx] == i).sum().item()
        ne = (labels[test_idx] == i).sum().item()
        print(f" {name:8s}: train={nt:6d} val={nv:5d} test={ne:5d}", flush=True)

    # Create datasets (in-memory, fast); only the training set is augmented.
    train_ds = InMemoryDataset(images[train_idx], labels[train_idx], augment=True)
    val_ds = InMemoryDataset(images[val_idx], labels[val_idx], augment=False)
    test_ds = InMemoryDataset(images[test_idx], labels[test_idx], augment=False)

    # Weighted sampler for class imbalance
    sampler = make_weighted_sampler(labels[train_idx])

    # num_workers=0: data already lives in RAM, worker processes would only add overhead.
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Model
    model = AccidentalCNN().to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {param_count:,}", flush=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # Halve the LR after 3 epochs without a validation-accuracy improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=3
    )

    # Training loop
    best_val_acc = 0.0
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    best_model_path = OUTPUT_DIR / "accidental_classifier.pt"

    # Per-epoch progress table header.
    print(f"\n{'Epoch':>5s} {'TrLoss':>7s} {'VaLoss':>7s} {'VaAcc':>6s} "
          f"{'Flat':>5s} {'Nat':>5s} {'Sharp':>5s} {'LR':>8s} {'Time':>5s}", flush=True)
    print("-" * 70, flush=True)

    for epoch in range(1, NUM_EPOCHS + 1):
        t0 = time.time()
        model.train()
        train_loss = 0.0
        n_batches = 0

        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = F.cross_entropy(outputs, lbls)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_batches += 1

        avg_train_loss = train_loss / n_batches
        val_loss, val_acc, class_acc = evaluate(model, val_loader, device)
        # Plateau scheduler tracks validation accuracy (mode="max" above).
        scheduler.step(val_acc)

        elapsed = time.time() - t0
        lr = optimizer.param_groups[0]["lr"]

        print(
            f"{epoch:5d} {avg_train_loss:7.4f} {val_loss:7.4f} {val_acc:6.1%} "
            f"{class_acc[0]:5.1%} {class_acc[1]:5.1%} {class_acc[2]:5.1%} "
            f"{lr:.6f} {elapsed:5.1f}s",
            flush=True,
        )

        # Checkpoint only on a new best validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)

    # Final evaluation on test set (using the best checkpoint, not the last epoch).
    print(f"\n=== Test Set Evaluation ===", flush=True)
    model.load_state_dict(torch.load(best_model_path, weights_only=True))
    test_loss, test_acc, test_class_acc = evaluate(model, test_loader, device)
    print(f"Test accuracy: {test_acc:.1%}", flush=True)
    for i, name in enumerate(LABELS):
        print(f" {name:8s}: {test_class_acc[i]:.1%}", flush=True)

    # Export to ONNX for OpenCV dnn inference
    print(f"\nExporting to ONNX...", flush=True)
    model.eval()
    dummy = torch.randn(1, 1, *IMG_SIZE)
    onnx_path = OUTPUT_DIR / "accidental_classifier.onnx"
    # Batch dimension is dynamic so inference can run on any batch size.
    torch.onnx.export(
        model,
        dummy,
        str(onnx_path),
        input_names=["image"],
        output_names=["logits"],
        dynamic_axes={"image": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=13,
    )
    print(f"Saved ONNX model: {onnx_path}", flush=True)
    print(f"Saved PyTorch model: {best_model_path}", flush=True)

    # Quick ONNX verification via OpenCV.
    # cv2 is imported here so training itself does not require OpenCV.
    import cv2
    net = cv2.dnn.readNetFromONNX(str(onnx_path))
    test_img = np.random.randint(0, 255, (IMG_SIZE[1], IMG_SIZE[0]), dtype=np.uint8)
    # blobFromImage applies the same 1/255 scaling used during training.
    blob = cv2.dnn.blobFromImage(test_img, 1.0 / 255.0, IMG_SIZE, 0, swapRB=False)
    net.setInput(blob)
    out = net.forward()
    pred = LABELS[np.argmax(out[0])]
    print(f"ONNX+OpenCV smoke test: input=random noise -> pred={pred} (expected: arbitrary)", flush=True)
    print(f"\nDone. Labels: {LABELS}", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: run the full train -> evaluate -> ONNX-export pipeline.
    main()
|
||||
Loading…
Reference in a new issue